feat(pt_expt): multi-task training support #5397
wanghan-iapcm wants to merge 15 commits into deepmodeling:master
Conversation
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 40962b2939
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
Force-pushed from 40962b2 to 512eeb6
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the corresponding setting. Use the following commands to manage reviews, or the checkboxes below for quick actions:
📝 Walkthrough

Adds multi-task training and parameter sharing: preprocessing of shared params, per-task model/loss/data/stat plumbing, probability-weighted env/fitting-stat merging, and many descriptor/fitting updates.

Changes
Sequence Diagram

```mermaid
sequenceDiagram
    participant User as User/Config
    participant Main as entrypoints.main
    participant Pre as preprocess_shared_params
    participant Trainer as Trainer
    participant Wrapper as ModelWrapper
    participant Model as Model (per-task)
    participant Descr as Descriptor
    participant Util as merge_env_stat
    User->>Main: submit config with model_dict (multi-task)
    Main->>Pre: preprocess_shared_params(model_config)
    Pre-->>Main: return (model_config', shared_links)
    Main->>Trainer: create Trainer(..., shared_links)
    Trainer->>Wrapper: init ModelWrapper with per-task modules
    Wrapper->>Wrapper: share_params(shared_links, model_key_prob_map)
    Wrapper->>Model: invoke model[task].share_params(base, shared_level)
    Model->>Descr: Descriptor.share_params(base_descr, shared_level)
    alt resume == False
        Descr->>Util: merge_env_stat(base_descr, self_descr, model_prob)
        Util-->>Descr: merged stats and updated buffers
    end
    Trainer->>Trainer: sample task_key by model_prob
    Trainer->>Wrapper: forward(batch, task_key)
    Wrapper->>Model: model[task_key].forward(batch)
    Model->>Trainer: outputs/loss -> Trainer (backprop & optimize)
```
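The probability-weighted merge step (`merge_env_stat` in the diagram) amounts to accumulating each linked task's statistics into the base task's, scaled by the linked task's sampling probability. A minimal numeric sketch — the dict-of-arrays layout here is an assumption for illustration, not the actual deepmd data structure:

```python
import numpy as np

def merge_env_stat(base_stats: dict, link_stats: dict, model_prob: float) -> dict:
    """Accumulate a linked task's environment statistics into the base
    task's, weighted by the linked task's sampling probability (sketch)."""
    merged = {}
    for kk in base_stats:
        merged[kk] = base_stats[kk] + link_stats[kk] * model_prob
    return merged

base = {"sumr": np.array([4.0, 2.0])}
link = {"sumr": np.array([2.0, 6.0])}
print(merge_env_stat(base, link, model_prob=0.5))  # sumr -> [5. 5.]
```

Chaining this merge over all linked heads yields shared statistics that reflect each task in proportion to how often it is sampled during training.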
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes

Possibly related PRs
Suggested reviewers
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 13
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
deepmd/dpmodel/fitting/general_fitting.py (1)
673-694: ⚠️ Potential issue | 🟠 Major — Restore explicit `fparam`/`aparam` shape checks before reshaping.

`xp.reshape()` will accept any tensor with the right element count, so malformed inputs like `(1, nf * nfp)` or `(nf, 1, nfp)` are now silently reinterpreted instead of rejected. `_call_common()` is the common path used by the other fitting classes, so this turns bad caller input into hard-to-debug numerical corruption rather than a clear validation error.

Suggested fix

```diff
 if self.numb_fparam > 0:
     assert fparam is not None, "fparam should not be None"
+    if fparam.shape != (nf, self.numb_fparam):
+        raise ValueError(
+            f"fparam has shape {fparam.shape}, expected {(nf, self.numb_fparam)}."
+        )
     fparam = xp.reshape(fparam, (nf, self.numb_fparam))
     fparam = (fparam - self.fparam_avg[...]) * self.fparam_inv_std[...]
     fparam = xp.tile(
         xp.reshape(fparam, (nf, 1, self.numb_fparam)), (1, nloc, 1)
     )
@@
 if self.numb_aparam > 0 and not self.use_aparam_as_mask:
     assert aparam is not None, "aparam should not be None"
+    if aparam.shape != (nf, nloc, self.numb_aparam):
+        raise ValueError(
+            f"aparam has shape {aparam.shape}, expected {(nf, nloc, self.numb_aparam)}."
+        )
     aparam = xp.reshape(aparam, (nf, nloc, self.numb_aparam))
     aparam = (aparam - self.aparam_avg[...]) * self.aparam_inv_std[...]
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/dpmodel/fitting/general_fitting.py` around lines 673 - 694, Restore explicit shape validation in _call_common: before calling xp.reshape on fparam and aparam, assert their ndim and dimensions match the expected shapes (for fparam: ndim == 2 and shape == (nf, self.numb_fparam); for aparam: ndim == 3 and shape == (nf, nloc, self.numb_aparam) or adjust to allowed alternatives if intended). Use clear assertion messages referencing the variables (fparam, aparam, nf, nloc, self.numb_fparam, self.numb_aparam) so malformed inputs are rejected before xp.reshape, then proceed with the existing normalization and concatenation logic.

deepmd/pt_expt/train/training.py (1)
864-868: ⚠️ Potential issue | 🟠 Major — Use `self._unwrapped.model[...]` in `_compile_model()` to access the wrapped model in DDP mode.

`enable_compile` is DDP-unsafe in the current implementation. When `_compile_model()` runs in distributed mode, `self.wrapper` has already been wrapped with `DistributedDataParallel`. After DDP wrapping, the original `ModelWrapper` is accessible only via `self.wrapper.module`, not `self.wrapper.model`. The direct access `self.wrapper.model[task_key]` at line 902 (and line 1004 when reassigning) will fail or return unexpected results.

The codebase already provides a `self._unwrapped` property that correctly handles DDP unwrapping. Replace `self.wrapper.model[...]` with `self._unwrapped.model[...]` at both locations.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/pt_expt/train/training.py` around lines 864 - 868, The _compile_model implementation is DDP-unsafe because it directly accesses self.wrapper.model[...] after DistributedDataParallel wrapping; update _compile_model (referenced as self._compile_model) to use self._unwrapped.model[...] instead of self.wrapper.model[...] for both the read (original access at the current access site) and the write/reassign (the later reassign at the second site) so the code correctly unwraps DDP and operates on the underlying ModelWrapper via the existing self._unwrapped property.
🧹 Nitpick comments (3)
deepmd/dpmodel/utils/env_mat_stat.py (1)
77-78: Potential KeyError if `link_stats` has mismatched keys.

The iteration assumes `base_stats` and `link_stats` have identical keys. If `link_obj.stats` has a different set of keys (e.g., different type counts or missing entries), accessing `link_stats[kk]` will raise a `KeyError`. Consider adding defensive handling:

🛡️ Proposed defensive fix

```diff
 for kk in base_stats:
+    if kk not in link_stats:
+        raise ValueError(f"Stats key '{kk}' missing in link_obj, cannot merge.")
     merged_stats[kk] = base_stats[kk] + link_stats[kk] * model_prob
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/dpmodel/utils/env_mat_stat.py` around lines 77 - 78, The loop in env_mat_stat.py assumes link_stats contains every key in base_stats, which can raise KeyError; update the loop in the function that builds merged_stats so it defensively fetches link_stats values (use link_stats.get(kk, default)) and choose a sensible default that matches base_stats' shape/type (e.g., 0 for scalars or numpy.zeros_like(base_stats[kk]) for arrays) before computing merged_stats[kk] = base_stats[kk] + link_val * model_prob; reference the variables merged_stats, base_stats, link_stats, model_prob to locate and change the code.

deepmd/pt_expt/descriptor/hybrid.py (1)
33-39: Rename unused loop variable `des` to `_des`.

The loop variable is not used within the loop body. Per Ruff B007, prefix it with an underscore.

🔧 Suggested fix

```diff
-    for ii, des in enumerate(self.descrpt_list):
+    for ii, _des in enumerate(self.descrpt_list):
         self.descrpt_list[ii].share_params(
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/pt_expt/descriptor/hybrid.py` around lines 33 - 39, The loop in hybrid.py uses an unused loop variable named des; update the for-loop header to use _des instead (for ii, _des in enumerate(self.descrpt_list):) so it complies with Ruff B007, leaving the body that calls self.descrpt_list[ii].share_params(base_class.descrpt_list[ii], shared_level, model_prob=model_prob, resume=resume) unchanged; ensure only the loop variable name is changed and no other logic is modified.

source/tests/pt_expt/descriptor/test_dpa3.py (1)
268-270: Consider using an underscore prefix for unused unpacked variables.

`nf` and `nloc` from `self.nlist.shape` are not used in this test. Per Ruff RUF059, prefix them with underscores.

🔧 Suggested fix

```diff
-    nf, nloc, nnei = self.nlist.shape
+    _nf, _nloc, nnei = self.nlist.shape
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@source/tests/pt_expt/descriptor/test_dpa3.py` around lines 268 - 270, The unpacking of self.nlist.shape assigns nf and nloc which are unused; change the unpack to use underscore-prefixed names (e.g., _nf, _nloc, nnei) so only nnei remains a meaningful variable while satisfying RUF059; update the tuple assignment where self.nlist.shape is unpacked in the test (currently "nf, nloc, nnei = self.nlist.shape") to use "_nf, _nloc, nnei = self.nlist.shape".
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@deepmd/pt_expt/entrypoints/main.py`:
- Around line 292-306: Wrap the trainer creation and run in a try/finally so
dist.destroy_process_group() is always called on exit: after calling
dist.init_process_group(...) when LOCAL_RANK is set, call get_trainer(...) and
trainer.run() inside a try block and in the finally check dist.is_available()
and dist.is_initialized() and call dist.destroy_process_group(); ensure
references to get_trainer, trainer.run, dist.init_process_group, and
dist.destroy_process_group are used so the teardown always runs even if
get_trainer() or trainer.run() raises.
- Around line 88-95: The single-task branch that creates stat directories uses
Path.mkdir() without parents=True which fails for nested paths; update the block
that checks stat_file_path existence (the code handling stat_file_path before
assigning DPPath) to call Path(stat_file_path).mkdir(parents=True,
exist_ok=True) for non-hdf5 paths (and preserve the existing h5py.File creation
for .h5/.hdf5), so nested directories are created the same way as in multi-task
mode before constructing DPPath(stat_file_path, "a").
In `@deepmd/pt_expt/fitting/invar_fitting.py`:
- Around line 45-77: The current logic can alias base_class buffers even when
base_class.get_param_stats().get("fparam") is empty, causing the first-shared
head to silently wipe later-populated stats; update the merge code in
invar_fitting.py (symbols: base_class, get_param_stats, fparam, _param_stats,
_buffers, fparam_avg, fparam_inv_std, self.get_param_stats, model_prob) so that
before aliasing you seed base_class's stats when base_stats is empty but
self_stats exists: if base_stats is falsy and self_stats truthy, set
base_class._param_stats["fparam"] = self_stats (or a copied/merged version) and
initialize base_class._buffers["fparam_avg"]/_inv_std from self_stats
averages/stds (compute with compute_avg/compute_std and copy_ into
base_class.fparam_avg and fparam_inv_std tensors) so that later aliasing of
self._buffers["fparam_avg"] = base_class._buffers["fparam_avg"] preserves the
actual populated stats; apply the same seeding pattern to the aparam branch (the
code at lines ~83-115) as well.
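The seed-then-alias ordering this prompt asks for can be shown with a stripped-down stand-in. Everything here is hypothetical except the `_param_stats` attribute name, which mirrors the review text; the real fitting classes also maintain avg/inv-std buffers that would be seeded the same way:

```python
import copy

class Head:
    """Hypothetical stand-in for a fitting head holding fparam stats."""
    def __init__(self, stats=None):
        self._param_stats = {"fparam": stats or []}

def share_fparam_stats(base: "Head", link: "Head") -> None:
    base_stats = base._param_stats["fparam"]
    self_stats = link._param_stats["fparam"]
    if not base_stats and self_stats:
        # Seed the base head first, so the first populated head's
        # stats survive the aliasing below instead of being dropped.
        base._param_stats["fparam"] = copy.deepcopy(self_stats)
    # Alias: the linked head now shares the base head's container.
    link._param_stats["fparam"] = base._param_stats["fparam"]

base, link = Head(), Head(stats=[{"sum": 3.0, "count": 2}])
share_fparam_stats(base, link)
print(base._param_stats["fparam"])  # [{'sum': 3.0, 'count': 2}]
```

Without the seeding branch, the alias would point both heads at the base's empty list and the populated head's statistics would be lost, making the shared result depend on model order.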
In `@deepmd/pt_expt/train/training.py`:
- Around line 638-643: The dict(zip(...)) calls silently truncate if
self.model_keys and self.model_prob differ; update both instances (the
model_key_prob_map creation in the self.wrapper.share_params call that zips
self.model_keys with self.model_prob, and the other zip around lines 844-850) to
use zip(..., strict=True) so a mismatch raises immediately; keep the dict(...)
wrapper and the same variable names (self.model_keys, self.model_prob) and
preserve surrounding parameters like resume, _finetune_update_stat, and
data_stat_protect.
- Around line 1226-1239: The loop over self.model_keys is advancing the training
iterator by calling get_data(is_train=True) for non-selected tasks, which
mutates the training stream and causes rank-0 drift; change it to sample from
validation or a non-advancing example instead (e.g., call
get_data(is_train=False, task_key=_key) or use a cached/peek sample API) and
remove any optimizer.zero_grad()/updates for these heads since they are only for
logging; keep using _unwrapped to compute the forward/loss but ensure you do not
consume training iterators or perform optimizer steps when populating
train_results (refer to model_keys, task_key, get_data, _unwrapped,
train_results).
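A cached, non-advancing sample source along these lines could look as follows — the helper names are hypothetical, not the trainer's actual API:

```python
class DisplaySampler:
    """Cache one batch per task for logging, so display-only metrics
    never consume the training iterators (a sketch)."""

    def __init__(self, get_valid_batch):
        self._get_valid_batch = get_valid_batch  # assumed non-training source
        self._cache = {}

    def peek(self, task_key):
        # Fetch once per task, then reuse; repeated display steps
        # do not advance any iterator.
        if task_key not in self._cache:
            self._cache[task_key] = self._get_valid_batch(task_key)
        return self._cache[task_key]

calls = []
sampler = DisplaySampler(lambda k: calls.append(k) or {"task": k})
sampler.peek("water")
sampler.peek("water")
print(calls)  # ['water']
```

Because only cached batches are used for the non-selected heads, `disp_freq` no longer changes the training stream and all ranks see identical iterator state.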
- Around line 844-850: The resume path re-applies sharing via
self._unwrapped.share_params(shared_links, resume=True, ...) but omits the
original data_stat_protect so it falls back to the default 1e-2; fix by passing
the same protection value used on initial share_params (e.g.
data_stat_protect=self._data_stat_protect or the validated variable you stored
earlier) into this resume call so the call becomes
self._unwrapped.share_params(shared_links, resume=True, model_key_prob_map=...,
data_stat_protect=self._data_stat_protect); if that validated value isn't stored
on self, store it when share_params is first invoked (or read it from the
wrapper) and reuse it here.
In `@deepmd/pt_expt/train/wrapper.py`:
- Around line 118-121: The division that computes frac_prob using
model_key_prob_map[model_key_link] / model_key_prob_map[model_key_base] can
divide by zero when the chosen shared base head has probability 0; before
computing frac_prob (and similarly in the other occurrence around the 150-153
area) check model_key_prob_map[model_key_base] != 0 and if zero either select an
alternative non-zero base (search linked heads for the first with non-zero
probability) or raise a clear validation error indicating the base head has zero
probability; update the logic around frac_prob, model_key_link and
model_key_base to perform this guard and error reporting so the ratio is only
computed with a non-zero denominator.
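The guard described above can be sketched as a small helper (the function name is illustrative; the map and key names follow the review text):

```python
def compute_frac_prob(model_key_prob_map, model_key_link, model_key_base):
    """Ratio of a linked head's probability to its base head's,
    guarding against a zero-probability base (a sketch)."""
    base_prob = model_key_prob_map[model_key_base]
    if base_prob == 0:
        raise ValueError(
            f"Shared base head '{model_key_base}' has zero probability; "
            "cannot scale linked statistics against it."
        )
    return model_key_prob_map[model_key_link] / base_prob

probs = {"base": 0.8, "link": 0.2}
print(compute_frac_prob(probs, "link", "base"))  # 0.25

try:
    compute_frac_prob({"base": 0.0, "link": 1.0}, "link", "base")
except ValueError as err:
    print("rejected:", err)
```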
In `@deepmd/pt_expt/utils/finetune.py`:
- Around line 107-111: The current assert checking model_branch should be
replaced with explicit CLI validation that raises a ValueError so it isn't
stripped by optimized runs; change the assert in finetune.py (the check
referencing model_branch and the 'finetune_head' message) to an if-statement
that raises ValueError with the same explanatory string when model_branch != ""
(e.g., if model_branch != "": raise ValueError("Multi-task fine-tuning does not
support command-line branches chosen! Please define the 'finetune_head' in each
model params!")).
In `@source/tests/pt_expt/fitting/test_fitting_stat.py`:
- Around line 148-150: The preflight skip helper _skip_if_no_data currently only
checks _PT_DATA and _PT_DATA_SINGLE but test_sharefitting_using_default_fparam
loads _PT_DATA_NO_FPARAM (data_1), causing FileNotFoundError on workers missing
that dataset; update the skip logic (in _skip_if_no_data or an analogous helper
used at lines ~218-221) to also check os.path.isdir(_PT_DATA_NO_FPARAM) and
raise unittest.SkipTest with a descriptive message if it's missing so the test
cleanly skips when data_1 is unavailable.
In `@source/tests/pt_expt/test_multitask.py`:
- Around line 1951-1952: The loop in test_multitask.py uses an unused variable
name n2 in the comprehension "for (n1, p1), (n2, p2) in
zip(mt_desc.named_parameters(), mt_desc_2.named_parameters(), strict=True)"
which triggers Ruff; rename n2 to _ (and keep p2 if used, otherwise rename to _)
so the loop becomes "for (n1, p1), (_, p2) in zip(...)" to mark the variable as
intentionally unused and satisfy the linter; update any references accordingly
and run ruff check . before committing.
- Around line 945-947: The assignment is binding an unused second return value
from get_finetune_rules to finetune_links_true which ruff will flag; update the
call so only the used value is captured (e.g. assign the first return to
model_config_true and discard the second by using an underscore or by indexing
the returned tuple) — modify the statement that currently reads
"model_config_true, finetune_links_true = get_finetune_rules(...)" to
"model_config_true, _ = get_finetune_rules(...)" or similar to remove the unused
finetune_links_true binding.
In `@source/tests/pt_expt/test_training_ddp.py`:
- Around line 62-66: The current _find_free_port() helper is racy for parallel
DDP because it closes the socket before workers call init_process_group(),
allowing the port to be grabbed by another process; replace this rendezvous
approach by allocating a shared rendezvous that persists across workers (for
example, create a temporary file-based rendezvous URL "file://..." or
instantiate a pre-created torch.distributed.TCPStore and pass its address), then
update the test setup that calls init_process_group() to use that persistent
rendezvous (or TCPStore) instead of the ephemeral port returned by
_find_free_port(), ensuring all workers reuse the same rendezvous resource.
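A `file://` rendezvous avoids the port race entirely, since the shared path persists until all workers have joined. A sketch (the helper name is hypothetical; torch's file-based init expects a path that does not yet exist and is reachable by every worker):

```python
import os
import tempfile

def make_file_rendezvous() -> str:
    """Build a persistent file:// rendezvous URL for init_process_group.

    Unlike probing a free TCP port and closing the socket, the shared
    file path cannot be stolen by another process between test setup
    and worker startup.
    """
    fd, path = tempfile.mkstemp(prefix="ddp_rdzv_")
    os.close(fd)
    os.unlink(path)  # file:// init expects a non-existent path
    return f"file://{path}"

url = make_file_rendezvous()
print(url)
# Each worker would then rendezvous with (not executed here):
#   dist.init_process_group("gloo", init_method=url, rank=rank, world_size=world)
```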
- Around line 612-1666: Wrap each mp.spawn call in a separate
multiprocessing.Process so you can enforce a 60s join timeout and fail the test
if the spawn does not complete; e.g. in test methods like
test_ddp_single_task_trains, test_ddp_multitask_trains,
test_ddp_gradient_equals_average, test_ddp_multitask_gradient,
TestDDPInitModel.test_ddp_init_model, TestDDPRestart.test_ddp_restart, all
finetune tests (TestDDPFinetune.test_ddp_finetune,
TestDDPFinetuneRandomFitting.test_ddp_finetune_random_fitting,
TestDDPFinetuneNewType.test_ddp_finetune_new_type,
TestDDPMultiTaskFinetune.test_ddp_multitask_finetune) and any other places
calling mp.spawn, create a subprocess that runs a lambda calling mp.spawn(...,
join=True), start it, call p.join(60), and if p.is_alive() after 60s
terminate/kill the process and call self.fail("DDP test timed out after 60s")
(also ensure you cleanup by terminating and joining the process). This enforces
the 60s timeout without changing the worker functions
(_worker_single_task_train, _worker_multitask_train, _worker_gradient_test,
_worker_multitask_gradient_test, _worker_check_resume, _worker_finetune,
_worker_multitask_finetune).
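The watchdog pattern described above can be factored into one helper; in the real tests `target` would be a function that calls `mp.spawn(..., join=True)`, while here a trivial target stands in:

```python
import multiprocessing as mp
import time

def run_with_timeout(target, args=(), timeout=60.0):
    """Run one DDP scenario in a child process and kill it after
    `timeout` seconds, so a stuck rendezvous fails the test instead
    of wedging CI (a sketch)."""
    p = mp.Process(target=target, args=args)
    p.start()
    p.join(timeout)
    if p.is_alive():
        p.terminate()
        p.join()
        raise TimeoutError(f"DDP test timed out after {timeout:.0f}s")

run_with_timeout(time.sleep, args=(0.01,), timeout=60.0)
print("fast scenario completed within the deadline")
```

Each test method would call `self.fail(...)` (or let the `TimeoutError` propagate) when the deadline is exceeded, keeping the 60-second budget without touching the worker functions.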
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: d5cdc4f7-6f63-429e-a313-64d43ae24ed7
📒 Files selected for processing (35)
- deepmd/dpmodel/descriptor/repformers.py
- deepmd/dpmodel/fitting/general_fitting.py
- deepmd/dpmodel/utils/env_mat_stat.py
- deepmd/pt/model/task/fitting.py
- deepmd/pt_expt/descriptor/dpa1.py
- deepmd/pt_expt/descriptor/dpa2.py
- deepmd/pt_expt/descriptor/dpa3.py
- deepmd/pt_expt/descriptor/hybrid.py
- deepmd/pt_expt/descriptor/se_atten_v2.py
- deepmd/pt_expt/descriptor/se_e2_a.py
- deepmd/pt_expt/descriptor/se_r.py
- deepmd/pt_expt/descriptor/se_t.py
- deepmd/pt_expt/descriptor/se_t_tebd.py
- deepmd/pt_expt/entrypoints/main.py
- deepmd/pt_expt/fitting/ener_fitting.py
- deepmd/pt_expt/fitting/invar_fitting.py
- deepmd/pt_expt/train/training.py
- deepmd/pt_expt/train/wrapper.py
- deepmd/pt_expt/utils/finetune.py
- deepmd/pt_expt/utils/multi_task.py
- source/tests/pt/test_fitting_stat.py
- source/tests/pt_expt/descriptor/test_descrpt_stat_merge.py
- source/tests/pt_expt/descriptor/test_dpa1.py
- source/tests/pt_expt/descriptor/test_dpa3.py
- source/tests/pt_expt/descriptor/test_hybrid.py
- source/tests/pt_expt/descriptor/test_se_atten_v2.py
- source/tests/pt_expt/descriptor/test_se_r.py
- source/tests/pt_expt/descriptor/test_se_t.py
- source/tests/pt_expt/descriptor/test_se_t_tebd.py
- source/tests/pt_expt/fitting/test_fitting_stat.py
- source/tests/pt_expt/test_change_bias.py
- source/tests/pt_expt/test_finetune.py
- source/tests/pt_expt/test_multitask.py
- source/tests/pt_expt/test_training.py
- source/tests/pt_expt/test_training_ddp.py
💤 Files with no reviewable changes (1)
- deepmd/pt/model/task/fitting.py
Actionable comments posted: 4
♻️ Duplicate comments (9)
deepmd/pt_expt/entrypoints/main.py (2)
292-306: ⚠️ Potential issue | 🟠 Major — Always destroy the process group in a `finally`.

If `get_trainer()` or `trainer.run()` raises after Line 293, `destroy_process_group()` is skipped and the worker stays in a bad distributed state.

Minimal fix

```diff
-    if os.environ.get("LOCAL_RANK") is not None:
-        dist.init_process_group(backend="cuda:nccl,cpu:gloo")
-
-    trainer = get_trainer(
-        config,
-        init_model,
-        restart,
-        finetune_model=finetune,
-        finetune_links=finetune_links,
-        shared_links=shared_links,
-    )
-    trainer.run()
-
-    if dist.is_available() and dist.is_initialized():
-        dist.destroy_process_group()
+    try:
+        if os.environ.get("LOCAL_RANK") is not None:
+            dist.init_process_group(backend="cuda:nccl,cpu:gloo")
+
+        trainer = get_trainer(
+            config,
+            init_model,
+            restart,
+            finetune_model=finetune,
+            finetune_links=finetune_links,
+            shared_links=shared_links,
+        )
+        trainer.run()
+    finally:
+        if dist.is_available() and dist.is_initialized():
+            dist.destroy_process_group()
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/pt_expt/entrypoints/main.py` around lines 292 - 306, Wrap the distributed init/destroy in a try/finally so the process group is always cleaned: call dist.init_process_group(...) (existing code) before creating the trainer, then surround get_trainer(...) and trainer.run() with try/finally and move the dist.destroy_process_group() into the finally block; keep the same guards using dist.is_available() and dist.is_initialized() before calling dist.destroy_process_group() to avoid errors if init failed.
88-95: ⚠️ Potential issue | 🟠 Major — Create nested stat directories in single-task mode too.

This branch still uses plain `mkdir()`, so a nested path like `./stat_files/model_1` fails in single-task mode even though the multi-task branch already handles parent creation.

Minimal fix

```diff
-    Path(stat_file_path).mkdir()
+    Path(stat_file_path).mkdir(parents=True, exist_ok=True)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/pt_expt/entrypoints/main.py` around lines 88 - 95, The current branch that creates stat_file_path in single-task mode uses Path(stat_file_path).mkdir() without parents=True, so nested directories (e.g., "./stat_files/model_1") fail; update the creation logic where stat_file_path is checked (the block that tests Path(stat_file_path).exists() and then calls Path(stat_file_path).mkdir()) to create parent directories as needed by calling mkdir with parents=True and exist_ok=True (or equivalent), ensuring the DPPath(stat_file_path, "a") line still executes afterwards; keep the HDF5 file branch unchanged.

source/tests/pt_expt/test_training_ddp.py (2)
62-66: ⚠️ Potential issue | 🟠 Major — Replace free-port probing with a persistent rendezvous.

This helper closes the socket before the workers call `init_process_group()`, so another process can grab the port in between and make the DDP suite flaky under parallel CI. A `file://` rendezvous or a pre-created `TCPStore` is safer here.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@source/tests/pt_expt/test_training_ddp.py` around lines 62 - 66, The _find_free_port helper closes the socket before workers call init_process_group, allowing a race where another process can take the port; replace this free-port probing with a persistent rendezvous mechanism (e.g., use a file:// rendezvous URL or create and reuse a pre-created torch.distributed.TCPStore) so the port/address remains reserved across worker startup. Update the test setup that calls _find_free_port to instead create a stable rendezvous target (or a shared TCPStore) and pass that rendezvous information into init_process_group (or pass the initialized TCPStore into the processes) to avoid the transient socket race.
622-631: ⚠️ Potential issue | 🟠 Major — Add a hard 60s timeout around each spawned training scenario.

A stuck rendezvous or worker here will hang the suite indefinitely because every test calls `mp.spawn(..., join=True)` with no outer timeout. Please wrap each spawned run in a watchdog process (or equivalent timeout mechanism) so the test fails fast instead of wedging CI. As per coding guidelines, `**/tests/**/*training*.py`: set training test timeouts to 60 seconds maximum for validation purposes, as real training takes hours or days.

Also applies to: 662-671, 708-713, 815-820, 879-892, 938-951, 1053-1058, 1097-1102, 1195-1200, 1549-1561
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@source/tests/pt_expt/test_training_ddp.py` around lines 622 - 631, Wrap the mp.spawn(...) call inside a watchdog multiprocessing.Process to enforce a 60s hard timeout: create a Process whose target executes mp.spawn(_worker_single_task_train, args=(2, port, self.data_dir, result_dict), nprocs=2, join=True), start it, call p.join(60), and if p.is_alive() then p.terminate(), p.join(), and raise an AssertionError (or fail the test) indicating a timeout; apply the same pattern for other tests that call mp.spawn so any stuck rendezvous/workers are killed and the test fails fast. Ensure you reference the same symbols (_worker_single_task_train and mp.spawn) and perform cleanup (terminate + join) before failing.

source/tests/pt_expt/fitting/test_fitting_stat.py (1)
148-150: ⚠️ Potential issue | 🟡 Minor — Extend the data preflight to cover `data_1` as well.

`test_sharefitting_using_default_fparam()` later reads `_PT_DATA_NO_FPARAM`, but the shared skip helper never checks that directory. On workers without `pt/water/data/data_1`, this turns into a `FileNotFoundError` instead of a clean skip.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@source/tests/pt_expt/fitting/test_fitting_stat.py` around lines 148 - 150, The preflight helper _skip_if_no_data currently only checks _PT_DATA; update it to also check the secondary dataset used in tests (_PT_DATA_NO_FPARAM or the data_1 directory) and raise unittest.SkipTest if either required directory is missing. In practice, modify _skip_if_no_data to verify os.path.isdir(_PT_DATA) and os.path.isdir(_PT_DATA_NO_FPARAM) (or the path for data_1) and produce a clear SkipTest message naming the missing directory so test_sharefitting_using_default_fparam() will be skipped instead of raising FileNotFoundError.

deepmd/pt_expt/train/training.py (1)
1226-1239: ⚠️ Potential issue | 🟠 Major — Don't advance training data just to log other tasks.

This rank-0-only logging path pulls extra training batches for the non-selected heads, so `disp_freq` changes the actual training stream and distributed runs drift further because only rank 0 consumes those iterators. Please switch these display-only metrics to validation data or a cached/non-advancing sample.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/pt_expt/train/training.py` around lines 1226 - 1239, The loop under self.model_keys currently pulls training batches for non-selected heads (self.get_data(is_train=True, task_key=_key)) which advances the training iterator on rank 0; change it to use validation/non-advancing samples instead (e.g., call self.get_data(is_train=False, task_key=_key) or fetch from a cached display sample like self._display_samples[_key] if available) and avoid calling optimizer.zero_grad on these display-only queries; keep using _unwrapped(...) to compute metrics but ensure the data source does not advance training iterators so train_results is populated without consuming training batches.

deepmd/pt_expt/fitting/invar_fitting.py (1)
45-77: ⚠️ Potential issue | 🟠 Major — Seed the base fitting stats before aliasing shared buffers.

If the first shared head has empty `fparam`/`aparam` stats and a later head is the first one with real data, these branches skip the merge and still alias `base_class`'s buffers into every branch. The resulting shared stats become model-order dependent and can drop the first populated stats entirely. Please initialize `base_class._param_stats[...]` and the corresponding avg/inv-std buffers from `self_stats` when `base_stats` is empty, then continue chaining merges from that seeded state.

Also applies to: 83-115
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/pt_expt/fitting/invar_fitting.py` around lines 45 - 77, When base_class has empty stats for "fparam" (base_stats == []), seed base_class._param_stats["fparam"] with a deep copy of self_stats and initialize base_class.fparam_avg and base_class.fparam_inv_std buffers from self_stats' computed averages/stds (using compute_avg/compute_std with protection and matching dtypes/device) before performing the merge logic that uses numb_fparam, model_prob and writes back merged stats; this ensures subsequent chaining uses the seeded state. Update the same pattern for the "aparam" branch as well (seed base_class._param_stats["aparam"] and base_class.aparam_avg/aparam_inv_std from self.get_param_stats()["aparam"] when base_stats empty) so aliasing via self._buffers[...] always reflects the first non-empty head.

source/tests/pt_expt/test_multitask.py (2)
1951-1952: ⚠️ Potential issue | 🟡 Minor

Unused loop variable `n2` should be marked intentional.

Rename the unused variable to `_` (or `_n2`) to satisfy Ruff.

Suggested fix

```diff
- for (n1, p1), (n2, p2) in zip(
+ for (n1, p1), (_, p2) in zip(
      mt_desc.named_parameters(), mt_desc_2.named_parameters(), strict=True
  ):
```

As per coding guidelines, "Install linter and run `ruff check .` before committing changes or the CI will fail".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@source/tests/pt_expt/test_multitask.py` around lines 1951 - 1952, The loop over mt_desc.named_parameters() and mt_desc_2.named_parameters() declares an unused variable n2 which triggers the linter; rename n2 to _ (or _n2) in the for-loop header that iterates over (n1, p1), (n2, p2) returned from mt_desc.named_parameters() and mt_desc_2.named_parameters() so the unused name is explicit and Ruff will be satisfied while keeping p2 used as before.
945-947: ⚠️ Potential issue | 🟡 Minor

Unused `finetune_links_true` still triggers Ruff and will fail CI.

This binding is unused in the test and should be explicitly discarded.

Suggested fix

```diff
- model_config_true, finetune_links_true = get_finetune_rules(
+ model_config_true, _ = get_finetune_rules(
      ckpt_path, deepcopy(ft_config_true["model"]), change_model_params=True
  )
```

As per coding guidelines, "Install linter and run `ruff check .` before committing changes or the CI will fail".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@source/tests/pt_expt/test_multitask.py` around lines 945 - 947, The test binds an unused value finetune_links_true from get_finetune_rules, which triggers the linter; update the call site in test_multitask.py so the unused return is explicitly discarded (e.g., assign to _ or use Python's unpacking like model_config_true, _ = get_finetune_rules(...)) while preserving the existing model_config_true and change_model_params=True arguments; ensure you only change the left-hand side of the assignment where get_finetune_rules(...) is called.
🧹 Nitpick comments (3)
deepmd/pt_expt/descriptor/hybrid.py (1)
18-41: LGTM with minor nit: unused loop variable.

The `share_params` implementation correctly delegates to each sub-descriptor. However, the loop variable `des` on line 33 is extracted but never used; only the index `ii` is needed.

🔧 Proposed fix to rename unused variable

```diff
  if shared_level == 0:
-     for ii, des in enumerate(self.descrpt_list):
+     for ii, _des in enumerate(self.descrpt_list):
          self.descrpt_list[ii].share_params(
```

Alternatively, iterate directly with `range(len(self.descrpt_list))`:

```diff
  if shared_level == 0:
-     for ii, des in enumerate(self.descrpt_list):
+     for ii in range(len(self.descrpt_list)):
          self.descrpt_list[ii].share_params(
```
Verify each finding against the current code and only fix it if needed. In `@deepmd/pt_expt/descriptor/hybrid.py` around lines 18 - 41, The for-loop in share_params uses an unused loop variable `des` (in the method share_params of the hybrid descriptor class), so replace the enumerate form with either a range-based loop (for ii in range(len(self.descrpt_list))) or use underscore for the unused value (for ii, _ in enumerate(self.descrpt_list)) to avoid the unused-variable warning; keep the body calling self.descrpt_list[ii].share_params(...) with the same arguments.

deepmd/dpmodel/utils/env_mat_stat.py (1)
88-109: Consider handling case where neither `davg`/`dstd` nor `mean`/`stddev` attributes exist.

The function silently does nothing if neither attribute pair exists, which could lead to subtle bugs where stats are computed but buffers aren't updated. A warning or explicit handling may help surface configuration issues.
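A minimal sketch of the fallback this comment asks for, assuming the attribute names discussed above (`base_obj`, `davg`/`dstd`, `mean`/`stddev`); the helper name `apply_env_stats` is hypothetical and not the actual deepmd function:

```python
import warnings


def apply_env_stats(base_obj, avg, std):
    """Write merged env-mat statistics onto whichever buffer pair the target
    object exposes, and warn loudly when neither pair exists instead of
    silently doing nothing. Illustrative only, not the deepmd implementation.
    """
    if hasattr(base_obj, "davg") and hasattr(base_obj, "dstd"):
        base_obj.davg = avg
        base_obj.dstd = std
    elif hasattr(base_obj, "mean") and hasattr(base_obj, "stddev"):
        base_obj.mean = avg
        base_obj.stddev = std
    else:
        # Surface the misconfiguration instead of dropping the stats.
        warnings.warn(
            f"{type(base_obj).__name__} has neither davg/dstd nor "
            "mean/stddev; computed statistics were not applied."
        )
```

Whether a warning or a hard error is appropriate depends on how strict the stat pipeline should be, as the comment notes.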
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/dpmodel/utils/env_mat_stat.py` around lines 88 - 109, The code currently updates buffers on base_obj when it has davg/dstd or mean/stddev but silently skips if neither pair exists; add explicit handling for that case in the same function: after the existing if/elif on hasattr(base_obj, "davg") / hasattr(base_obj, "mean"), add an else branch that logs a warning (or raises an exception depending on desired strictness) referencing base_obj and the missing attributes, e.g., mentioning that neither davg/dstd nor mean/stddev were found and stats were not applied; ensure the logger or exception uses identifiers like base_obj, davg, dstd, mean, stddev, and respects set_davg_zero semantics if relevant.

source/tests/pt_expt/descriptor/test_dpa3.py (1)
264-313: LGTM with minor nit: unused unpacked variables.

The test correctly validates sharing semantics for DPA3. The unpacked variables `nf` and `nloc` on line 268 are unused; only `nnei` is needed.

🔧 Proposed fix

```diff
- nf, nloc, nnei = self.nlist.shape
+ _, _, nnei = self.nlist.shape
```
Verify each finding against the current code and only fix it if needed. In `@source/tests/pt_expt/descriptor/test_dpa3.py` around lines 264 - 313, In test_share_params, remove the unused unpacked variables nf and nloc from the assignment of self.nlist.shape; only extract nnei (e.g., use a single-element assignment or ignore the first two values with placeholders) so the test no longer defines unused variables; update the line in test_share_params that currently reads "nf, nloc, nnei = self.nlist.shape" to only capture nnei (referencing the test_share_params function and the self.nlist.shape tuple).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@deepmd/pt_expt/train/training.py`:
- Around line 901-902: In _compile_model(), replace direct access to
self.wrapper.model[task_key] with the unwrapped module via self._unwrapped
(e.g., use self._unwrapped.model[task_key] or retrieve the task model from
self._unwrapped) so that code works when DistributedDataParallel has wrapped the
model; update all occurrences (including the similar block around the later
lines noted) to use self._unwrapped instead of self.wrapper to avoid
AttributeError when DDP is active.
In `@deepmd/pt_expt/utils/multi_task.py`:
- Around line 55-57: The code that parses key_in_dict (the block using if ":" in
key_in_dict and assigning shared_key/shared_level) should guard against multiple
colons and non-numeric levels: replace the naive split with a right-split
(rsplit(":", 1)) to obtain at most two parts, validate that you got exactly two
parts, and convert the second part to int inside a try/except (or with isdigit
check) raising or handling malformed inputs appropriately; update references to
shared_key and shared_level accordingly so downstream logic uses the validated
values.
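The parsing hardening described above can be sketched as follows; the helper name `parse_shared_key` and the default level of 0 are assumptions for illustration, not the deepmd API:

```python
def parse_shared_key(key_in_dict: str) -> tuple[str, int]:
    """Parse a "name[:level]" sharing key. rsplit(":", 1) takes the rightmost
    colon as the level separator, and the level is validated as an integer
    instead of being split naively.
    """
    if ":" not in key_in_dict:
        return key_in_dict, 0  # assumed default shared level
    parts = key_in_dict.rsplit(":", 1)
    if len(parts) != 2 or not parts[1]:
        raise ValueError(f"malformed shared key: {key_in_dict!r}")
    shared_key, level_str = parts
    try:
        shared_level = int(level_str)
    except ValueError as e:
        raise ValueError(
            f"shared level must be an integer, got {level_str!r} "
            f"in {key_in_dict!r}"
        ) from e
    return shared_key, shared_level
```

With this shape, a non-numeric level such as `"descriptor:one"` raises a descriptive `ValueError` instead of failing downstream.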
In `@source/tests/pt_expt/test_multitask.py`:
- Around line 1171-1176: Docstring in the test block says it should only run for
"se_e2_a" but the guard uses descriptor.get("type") != "dpa3", causing the
mismatch; update the guard in the test methods (the conditional that checks
descriptor.get("type")) to check for "se_e2_a" instead of "dpa3" (or
alternatively update the docstrings to say "dpa3" if that was intended) and make
the same change in the other two occurrences mentioned so the docstring and the
guard in test_multitask.py are consistent.
In `@source/tests/pt_expt/test_training_ddp.py`:
- Around line 33-47: The module-level imports (get_trainer, get_finetune_rules,
preprocess_shared_params, normalize, update_deepmd_input) cause DEVICE to be
evaluated at import time; move the os.environ["DEVICE"] = "cpu" assignment so it
runs before importing any deepmd.pt_expt modules inside each spawned worker
function (i.e., set DEVICE at the top of each worker entrypoint before
performing imports) so child processes import deepmd.pt_expt with DEVICE=cpu;
update all worker entrypoints referenced (the ones that later import get_trainer
/ get_finetune_rules / preprocess_shared_params / normalize /
update_deepmd_input) to follow this order.
---
Duplicate comments:
In `@deepmd/pt_expt/entrypoints/main.py`:
- Around line 292-306: Wrap the distributed init/destroy in a try/finally so the
process group is always cleaned: call dist.init_process_group(...) (existing
code) before creating the trainer, then surround get_trainer(...) and
trainer.run() with try/finally and move the dist.destroy_process_group() into
the finally block; keep the same guards using dist.is_available() and
dist.is_initialized() before calling dist.destroy_process_group() to avoid
errors if init failed.
- Around line 88-95: The current branch that creates stat_file_path in
single-task mode uses Path(stat_file_path).mkdir() without parents=True, so
nested directories (e.g., "./stat_files/model_1") fail; update the creation
logic where stat_file_path is checked (the block that tests
Path(stat_file_path).exists() and then calls Path(stat_file_path).mkdir()) to
create parent directories as needed by calling mkdir with parents=True and
exist_ok=True (or equivalent), ensuring the DPPath(stat_file_path, "a") line
still executes afterwards; keep the HDF5 file branch unchanged.
In `@deepmd/pt_expt/fitting/invar_fitting.py`:
- Around line 45-77: When base_class has empty stats for "fparam" (base_stats ==
[]), seed base_class._param_stats["fparam"] with a deep copy of self_stats and
initialize base_class.fparam_avg and base_class.fparam_inv_std buffers from
self_stats' computed averages/stds (using compute_avg/compute_std with
protection and matching dtypes/device) before performing the merge logic that
uses numb_fparam, model_prob and writes back merged stats; this ensures
subsequent chaining uses the seeded state. Update the same pattern for the
"aparam" branch as well (seed base_class._param_stats["aparam"] and
base_class.aparam_avg/aparam_inv_std from self.get_param_stats()["aparam"] when
base_stats empty) so aliasing via self._buffers[...] always reflects the first
non-empty head.
In `@deepmd/pt_expt/train/training.py`:
- Around line 1226-1239: The loop under self.model_keys currently pulls training
batches for non-selected heads (self.get_data(is_train=True, task_key=_key))
which advances the training iterator on rank 0; change it to use
validation/non-advancing samples instead (e.g., call
self.get_data(is_train=False, task_key=_key) or fetch from a cached display
sample like self._display_samples[_key] if available) and avoid calling
optimizer.zero_grad on these display-only queries; keep using _unwrapped(...) to
compute metrics but ensure the data source does not advance training iterators
so train_results is populated without consuming training batches.
In `@source/tests/pt_expt/fitting/test_fitting_stat.py`:
- Around line 148-150: The preflight helper _skip_if_no_data currently only
checks _PT_DATA; update it to also check the secondary dataset used in tests
(_PT_DATA_NO_FPARAM or the data_1 directory) and raise unittest.SkipTest if
either required directory is missing. In practice, modify _skip_if_no_data to
verify os.path.isdir(_PT_DATA) and os.path.isdir(_PT_DATA_NO_FPARAM) (or the
path for data_1) and produce a clear SkipTest message naming the missing
directory so test_sharefitting_using_default_fparam() will be skipped instead of
raising FileNotFoundError.
In `@source/tests/pt_expt/test_multitask.py`:
- Around line 1951-1952: The loop over mt_desc.named_parameters() and
mt_desc_2.named_parameters() declares an unused variable n2 which triggers the
linter; rename n2 to _ (or _n2) in the for-loop header that iterates over (n1,
p1), (n2, p2) returned from mt_desc.named_parameters() and
mt_desc_2.named_parameters() so the unused name is explicit and Ruff will be
satisfied while keeping p2 used as before.
- Around line 945-947: The test binds an unused value finetune_links_true from
get_finetune_rules, which triggers the linter; update the call site in
test_multitask.py so the unused return is explicitly discarded (e.g., assign to
_ or use Python's unpacking like model_config_true, _ = get_finetune_rules(...))
while preserving the existing model_config_true and change_model_params=True
arguments; ensure you only change the left-hand side of the assignment where
get_finetune_rules(...) is called.
In `@source/tests/pt_expt/test_training_ddp.py`:
- Around line 62-66: The _find_free_port helper closes the socket before workers
call init_process_group, allowing a race where another process can take the
port; replace this free-port probing with a persistent rendezvous mechanism
(e.g., use a file:// rendezvous URL or create and reuse a pre-created
torch.distributed.TCPStore) so the port/address remains reserved across worker
startup. Update the test setup that calls _find_free_port to instead create a
stable rendezvous target (or a shared TCPStore) and pass that rendezvous
information into init_process_group (or pass the initialized TCPStore into the
processes) to avoid the transient socket race.
- Around line 622-631: Wrap the mp.spawn(...) call inside a watchdog
multiprocessing.Process to enforce a 60s hard timeout: create a Process whose
target executes mp.spawn(_worker_single_task_train, args=(2, port,
self.data_dir, result_dict), nprocs=2, join=True), start it, call p.join(60),
and if p.is_alive() then p.terminate(), p.join(), and raise an AssertionError
(or fail the test) indicating a timeout; apply the same pattern for other tests
that call mp.spawn so any stuck rendezvous/workers are killed and the test fails
fast. Ensure you reference the same symbols (_worker_single_task_train and
mp.spawn) and perform cleanup (terminate + join) before failing.
---
Nitpick comments:
In `@deepmd/dpmodel/utils/env_mat_stat.py`:
- Around line 88-109: The code currently updates buffers on base_obj when it has
davg/dstd or mean/stddev but silently skips if neither pair exists; add explicit
handling for that case in the same function: after the existing if/elif on
hasattr(base_obj, "davg") / hasattr(base_obj, "mean"), add an else branch that
logs a warning (or raises an exception depending on desired strictness)
referencing base_obj and the missing attributes, e.g., mentioning that neither
davg/dstd nor mean/stddev were found and stats were not applied; ensure the
logger or exception uses identifiers like base_obj, davg, dstd, mean, stddev,
and respects set_davg_zero semantics if relevant.
In `@deepmd/pt_expt/descriptor/hybrid.py`:
- Around line 18-41: The for-loop in share_params uses an unused loop variable
`des` (in the method share_params of the hybrid descriptor class), so replace
the enumerate form with either a range-based loop (for ii in
range(len(self.descrpt_list))) or use underscore for the unused value (for ii, _
in enumerate(self.descrpt_list)) to avoid the unused-variable warning; keep the
body calling self.descrpt_list[ii].share_params(...) with the same arguments.
In `@source/tests/pt_expt/descriptor/test_dpa3.py`:
- Around line 264-313: In test_share_params, remove the unused unpacked
variables nf and nloc from the assignment of self.nlist.shape; only extract nnei
(e.g., use a single-element assignment or ignore the first two values with
placeholders) so the test no longer defines unused variables; update the line in
test_share_params that currently reads "nf, nloc, nnei = self.nlist.shape" to
only capture nnei (referencing the test_share_params function and the
self.nlist.shape tuple).
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 2d5f2955-9a5d-4474-a90d-cbe31af58bd5
📒 Files selected for processing (35)
- deepmd/dpmodel/descriptor/repformers.py
- deepmd/dpmodel/fitting/general_fitting.py
- deepmd/dpmodel/utils/env_mat_stat.py
- deepmd/pt/model/task/fitting.py
- deepmd/pt_expt/descriptor/dpa1.py
- deepmd/pt_expt/descriptor/dpa2.py
- deepmd/pt_expt/descriptor/dpa3.py
- deepmd/pt_expt/descriptor/hybrid.py
- deepmd/pt_expt/descriptor/se_atten_v2.py
- deepmd/pt_expt/descriptor/se_e2_a.py
- deepmd/pt_expt/descriptor/se_r.py
- deepmd/pt_expt/descriptor/se_t.py
- deepmd/pt_expt/descriptor/se_t_tebd.py
- deepmd/pt_expt/entrypoints/main.py
- deepmd/pt_expt/fitting/ener_fitting.py
- deepmd/pt_expt/fitting/invar_fitting.py
- deepmd/pt_expt/train/training.py
- deepmd/pt_expt/train/wrapper.py
- deepmd/pt_expt/utils/finetune.py
- deepmd/pt_expt/utils/multi_task.py
- source/tests/pt/test_fitting_stat.py
- source/tests/pt_expt/descriptor/test_descrpt_stat_merge.py
- source/tests/pt_expt/descriptor/test_dpa1.py
- source/tests/pt_expt/descriptor/test_dpa3.py
- source/tests/pt_expt/descriptor/test_hybrid.py
- source/tests/pt_expt/descriptor/test_se_atten_v2.py
- source/tests/pt_expt/descriptor/test_se_r.py
- source/tests/pt_expt/descriptor/test_se_t.py
- source/tests/pt_expt/descriptor/test_se_t_tebd.py
- source/tests/pt_expt/fitting/test_fitting_stat.py
- source/tests/pt_expt/test_change_bias.py
- source/tests/pt_expt/test_finetune.py
- source/tests/pt_expt/test_multitask.py
- source/tests/pt_expt/test_training.py
- source/tests/pt_expt/test_training_ddp.py
💤 Files with no reviewable changes (1)
- deepmd/pt/model/task/fitting.py
✅ Files skipped from review due to trivial changes (3)
- source/tests/pt/test_fitting_stat.py
- source/tests/pt_expt/descriptor/test_se_t.py
- deepmd/pt_expt/descriptor/dpa1.py
🚧 Files skipped from review as they are similar to previous changes (16)
- source/tests/pt_expt/test_change_bias.py
- deepmd/pt_expt/fitting/ener_fitting.py
- deepmd/pt_expt/descriptor/se_atten_v2.py
- deepmd/dpmodel/descriptor/repformers.py
- source/tests/pt_expt/descriptor/test_se_atten_v2.py
- source/tests/pt_expt/descriptor/test_hybrid.py
- source/tests/pt_expt/descriptor/test_se_t_tebd.py
- deepmd/pt_expt/descriptor/dpa3.py
- deepmd/pt_expt/descriptor/se_t_tebd.py
- deepmd/pt_expt/descriptor/se_t.py
- deepmd/pt_expt/descriptor/dpa2.py
- source/tests/pt_expt/test_finetune.py
- deepmd/dpmodel/fitting/general_fitting.py
- deepmd/pt_expt/utils/finetune.py
- deepmd/pt_expt/train/wrapper.py
- source/tests/pt_expt/descriptor/test_descrpt_stat_merge.py
Codecov Report

❌ Patch coverage is
Additional details and impacted files

```
@@            Coverage Diff             @@
##           master    #5397      +/-   ##
==========================================
+ Coverage   80.35%   80.46%   +0.10%
==========================================
  Files         819      820       +1
  Lines       85446    85990     +544
  Branches     4140     4139       -1
==========================================
+ Hits        68661    69191     +530
- Misses      15509    15521      +12
- Partials     1276     1278       +2
```

☔ View full report in Codecov by Sentry.
The reshape-first approach changed the error from a custom ValueError to a generic reshape error, breaking test_self_exception assertions.
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
deepmd/dpmodel/fitting/general_fitting.py (1)
676-705: ⚠️ Potential issue | 🟠 Major

Add explicit shape validation to prevent silent data misalignment in fparam/aparam reshape
Lines 677 and 700 accept any tensor with matching element count. This allows silently reinterpreting axis-swapped tensors—for example, a tensor with shape (numb_fparam, nf) would reshape to (nf, numb_fparam) with axes swapped, causing frame/atom index misalignment.
Proposed fix
```diff
         if self.numb_fparam > 0:
             assert fparam is not None, "fparam should not be None"
+            expected_fparam_shape = (nf, self.numb_fparam)
+            if fparam.shape != expected_fparam_shape:
+                raise ValueError(
+                    f"input fparam: expected shape {expected_fparam_shape}, got {fparam.shape}."
+                )
             try:
-                fparam = xp.reshape(fparam, (nf, self.numb_fparam))
+                fparam = xp.reshape(fparam, expected_fparam_shape)
             except (ValueError, RuntimeError) as e:
                 raise ValueError(
                     f"input fparam: cannot reshape {fparam.shape} "
                     f"into ({nf}, {self.numb_fparam})."
                 ) from e
@@
         if self.numb_aparam > 0 and not self.use_aparam_as_mask:
             assert aparam is not None, "aparam should not be None"
+            expected_aparam_shape = (nf, nloc, self.numb_aparam)
+            if aparam.shape != expected_aparam_shape:
+                raise ValueError(
+                    f"input aparam: expected shape {expected_aparam_shape}, got {aparam.shape}."
+                )
             try:
-                aparam = xp.reshape(aparam, (nf, nloc, self.numb_aparam))
+                aparam = xp.reshape(aparam, expected_aparam_shape)
             except (ValueError, RuntimeError) as e:
                 raise ValueError(
                     f"input aparam: cannot reshape {aparam.shape} "
                     f"into ({nf}, {nloc}, {self.numb_aparam})."
                 ) from e
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/dpmodel/fitting/general_fitting.py` around lines 676 - 705, The fparam/aparam reshape code (variables fparam, aparam) currently only checks element count and can silently accept axis-swapped arrays; add explicit shape validation before calling xp.reshape: verify fparam.shape is one of {(nf, self.numb_fparam), (nf*self.numb_fparam,)} or fail with a clear ValueError including the actual shape and expected (nf, self.numb_fparam); similarly validate aparam.shape is one of {(nf, nloc, self.numb_aparam), (nf*nloc*self.numb_aparam,)} before reshape and raise a descriptive error if not; keep the existing normalization/tiling and concatenation (xx, xx_zeros) unchanged but only proceed after the explicit shape checks so misaligned axes are rejected rather than silently reshaped.
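The failure mode the prompt describes can be shown with plain tuples standing in for array shapes; `validate_param_shape` is a hypothetical helper sketched for this review, not code from `general_fitting.py`:

```python
def validate_param_shape(shape, allowed, name):
    """Accept only the documented layouts and reject anything that merely
    has the right element count, which a bare reshape would accept.
    """
    if shape not in allowed:
        raise ValueError(
            f"input {name}: expected shape in {sorted(allowed)}, got {shape}."
        )


nf, numb_fparam = 4, 3
allowed = {(nf, numb_fparam), (nf * numb_fparam,)}

# The documented (nf, numb_fparam) layout and its flat form pass.
validate_param_shape((nf, numb_fparam), allowed, "fparam")
validate_param_shape((nf * numb_fparam,), allowed, "fparam")

# An axis-swapped (numb_fparam, nf) tensor has the same element count,
# so reshape alone would silently misalign frames; the explicit check
# rejects it.
try:
    validate_param_shape((numb_fparam, nf), allowed, "fparam")
except ValueError as e:
    assert "fparam" in str(e)
```

The same pattern extends to the `(nf, nloc, numb_aparam)` case in the diff above it.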
🧹 Nitpick comments (1)
deepmd/dpmodel/fitting/general_fitting.py (1)
413-415: Consider returning a defensive copy from `get_param_stats`

`GeneralFitting.get_param_stats()` currently returns internal mutable state directly. Returning a shallow copy (or read-only mapping) would reduce accidental external mutation of `GeneralFitting._param_stats`.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/dpmodel/fitting/general_fitting.py` around lines 413 - 415, get_param_stats currently returns the internal mutable attribute _param_stats directly which risks external mutation; change GeneralFitting.get_param_stats to return a defensive copy or read-only view (e.g., return dict(self._param_stats) for a shallow copy or wrap with types.MappingProxyType for an immutable mapping) so callers cannot accidentally modify the internal _param_stats; ensure the method still returns the same typing dict[str, list[StatItem]] and keep behavior when _param_stats is missing by copying the fallback {}.
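Both options from the prompt, shallow copy and `MappingProxyType`, can be sketched on a toy class; `StatsHolder` and its stat values are illustrative stand-ins, not the real `GeneralFitting`:

```python
from types import MappingProxyType


class StatsHolder:
    """Toy holder demonstrating defensive returns for internal stats."""

    def __init__(self):
        self._param_stats = {"fparam": [1, 2], "aparam": []}

    def get_param_stats_copy(self):
        # Shallow copy: callers can add or remove keys without
        # touching the internal dict (inner lists are still shared).
        return dict(self._param_stats)

    def get_param_stats_view(self):
        # Read-only live view: mutation attempts raise TypeError,
        # but updates made internally remain visible to the caller.
        return MappingProxyType(self._param_stats)
```

The proxy is cheaper and stays in sync with internal updates; the copy fully decouples the caller at the top level. Either choice closes the accidental-mutation hole the comment points out.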
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Outside diff comments:
In `@deepmd/dpmodel/fitting/general_fitting.py`:
- Around line 676-705: The fparam/aparam reshape code (variables fparam, aparam)
currently only checks element count and can silently accept axis-swapped arrays;
add explicit shape validation before calling xp.reshape: verify fparam.shape is
one of {(nf, self.numb_fparam), (nf*self.numb_fparam,)} or fail with a clear
ValueError including the actual shape and expected (nf, self.numb_fparam);
similarly validate aparam.shape is one of {(nf, nloc, self.numb_aparam),
(nf*nloc*self.numb_aparam,)} before reshape and raise a descriptive error if
not; keep the existing normalization/tiling and concatenation (xx, xx_zeros)
unchanged but only proceed after the explicit shape checks so misaligned axes
are rejected rather than silently reshaped.
---
Nitpick comments:
In `@deepmd/dpmodel/fitting/general_fitting.py`:
- Around line 413-415: get_param_stats currently returns the internal mutable
attribute _param_stats directly which risks external mutation; change
GeneralFitting.get_param_stats to return a defensive copy or read-only view
(e.g., return dict(self._param_stats) for a shallow copy or wrap with
types.MappingProxyType for an immutable mapping) so callers cannot accidentally
modify the internal _param_stats; ensure the method still returns the same
typing dict[str, list[StatItem]] and keep behavior when _param_stats is missing
by copying the fallback {}.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 833f6b73-5ddd-4021-a87c-844940a6fe51
📒 Files selected for processing (1)
deepmd/dpmodel/fitting/general_fitting.py
- Remove unused _min_nbor_dist variables in main.py
- Remove unused _MULTITASK_JSON/_MULTITASK_SHAREFIT_JSON globals
- Fix duplicate unittest import (use unittest.mock.patch directly)
- Bind test socket to 127.0.0.1 instead of all interfaces
- Remove redundant nframes assignment in _compile_model
Under DDP self.wrapper is DistributedDataParallel which has no .model attribute. Use .module to access the underlying ModelWrapper.
Actionable comments posted: 1
♻️ Duplicate comments (9)
deepmd/pt_expt/train/training.py (3)
901-903: ⚠️ Potential issue | 🔴 Critical

Use `_unwrapped.model` inside `_compile_model()`.

After DDP wrapping, `self.wrapper` is a `DistributedDataParallel`, so `self.wrapper.model[task_key]` here will fail when compile is enabled in distributed mode. Read and write through `self._unwrapped.model[...]` instead.

Minimal fix

```diff
 for task_key in self.model_keys:
-    model = self.wrapper.model[task_key]
+    model = self._unwrapped.model[task_key]
     ...
-    self.wrapper.model[task_key] = _CompiledModel(
+    self._unwrapped.model[task_key] = _CompiledModel(
         model, compiled_lower, max_nall, task_compile_opts
     )
```

```bash
#!/bin/bash
# Verify that DDP wrapping happens before _compile_model(), and that
# _compile_model() still dereferences self.wrapper.model instead of self._unwrapped.model.
sed -n '645,668p' deepmd/pt_expt/train/training.py
printf '\n---\n'
sed -n '901,1005p' deepmd/pt_expt/train/training.py
printf '\n---\n'
sed -n '1074,1079p' deepmd/pt_expt/train/training.py
```

Expected result: the first block shows `self.wrapper` being replaced by `DistributedDataParallel`, the second block shows direct `self.wrapper.model[...]` access inside `_compile_model()`, and the third block shows the existing `_unwrapped` helper that should be used instead.

Also applies to: 1003-1005
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/pt_expt/train/training.py` around lines 901 - 903, Change _compile_model to access model modules via self._unwrapped.model instead of self.wrapper.model: iterate over self.model_keys and use self._unwrapped.model[task_key] when reading or assigning models (replace occurrences of self.wrapper.model[...] in _compile_model). This ensures compatibility after DistributedDataParallel wrapping; search for _compile_model, self.wrapper, and self._unwrapped.model to update all direct wrapper.model dereferences (also apply same change around lines referenced for the second occurrence).
638-642: ⚠️ Potential issue | 🟡 Minor

Make both probability-map zips strict.

Both `dict(zip(...))` calls still silently truncate if `self.model_keys` and `self.model_prob` ever drift, and Ruff already flags both sites.

Minimal fix

```diff
- model_key_prob_map=dict(zip(self.model_keys, self.model_prob)),
+ model_key_prob_map=dict(
+     zip(self.model_keys, self.model_prob, strict=True)
+ ),
```

```diff
- model_key_prob_map=dict(zip(self.model_keys, self.model_prob)),
+ model_key_prob_map=dict(
+     zip(self.model_keys, self.model_prob, strict=True)
+ ),
```

As per coding guidelines, `**/*.py`: Install linter and run `ruff check .` before committing changes or the CI will fail.

Also applies to: 846-849
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/pt_expt/train/training.py` around lines 638 - 642, Summary: Replace silent-truncating dict(zip(...)) usages with strict length checks to prevent silent loss when self.model_keys and self.model_prob drift. Fix: before calling self.wrapper.share_params(...) where model_key_prob_map=dict(zip(self.model_keys, self.model_prob)) (and the other similar call later), validate that len(self.model_keys) == len(self.model_prob) and raise a clear ValueError if not; then construct model_key_prob_map using dict(zip(...)) safely. Reference symbols: self.wrapper.share_params, model_key_prob_map, self.model_keys, self.model_prob; apply the same validation and construction at the other occurrence flagged in the review.
1225-1235: ⚠️ Potential issue | 🟠 Major

Per-task logging still advances the training iterators.

On rank 0, this logging path fetches `get_data(is_train=True, task_key=_key)` for non-active heads just to populate metrics. That makes `disp_freq` change which batches those heads actually train on, and only rank 0 consumes those extra batches.
Verify each finding against the current code and only fix it if needed. In `@deepmd/pt_expt/train/training.py` around lines 1225 - 1235, The per-task logging loop is advancing training iterators because it calls get_data(is_train=True, task_key=_key) for non-active heads; change that so non-active heads do not consume training batches—either skip the get_data/_unwrapped call for _key != task_key or call get_data(is_train=False, task_key=_key) (or another non-consuming/validation accessor) so the training iterator is not advanced; update the loop around self.model_keys (the block using self.optimizer.zero_grad(), self.get_data(...), and self._unwrapped(...)) to only fetch training data for the active task_key and use a non-consuming path for logging metrics.source/tests/pt_expt/test_multitask.py (1)
939-941: ⚠️ Potential issue | 🟡 Minor

Drop the unused test bindings before Ruff does.

`finetune_links_true` and `n2` are never read, so this file still trips the required lint step.

Minimal fix

```diff
- model_config_true, finetune_links_true = get_finetune_rules(
+ model_config_true, _ = get_finetune_rules(
      ckpt_path, deepcopy(ft_config_true["model"]), change_model_params=True
  )
```

```diff
- for (n1, p1), (n2, p2) in zip(
+ for (n1, p1), (_, p2) in zip(
      mt_desc.named_parameters(), mt_desc_2.named_parameters(), strict=True
  ):
```

As per coding guidelines, `**/*.py`: Install linter and run `ruff check .` before committing changes or the CI will fail.

Also applies to: 1945-1947
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@source/tests/pt_expt/test_multitask.py` around lines 939 - 941, The test creates unused bindings (finetune_links_true and n2) from get_finetune_rules and another unpack that Ruff flags; remove those unused variables by changing the unpack to only capture what is used (e.g., assign just model_config_true = get_finetune_rules(...)[0] or unpack as model_config_true, _ = get_finetune_rules(...)) and similarly drop or replace n2 with an underscore in the other unpack; update the lines that call get_finetune_rules and any other tuple unpacking to avoid creating unused names.source/tests/pt_expt/test_training_ddp.py (3)
622-631: ⚠️ Potential issue | 🟠 Major

Bound each spawned training case to 60 seconds.

All of these `mp.spawn(..., join=True)` paths can hang indefinitely on rendezvous or worker failures. Please route them through a helper that enforces a 60s timeout and fails the test if the subprocess stays alive.

As per coding guidelines, `**/tests/**/*training*.py`: Set training test timeouts to 60 seconds maximum for validation purposes, as real training takes hours or days.

Also applies to: 666-671, 708-713, 815-820, 879-892, 938-951, 1053-1058, 1097-1102, 1195-1200, 1549-1560
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@source/tests/pt_expt/test_training_ddp.py` around lines 622 - 631, Wrap the mp.spawn call in test_ddp_single_task_trains with a helper that enforces a 60s timeout and fails the test if any spawned subprocess remains alive; specifically replace the direct mp.spawn(...) invocation that launches _worker_single_task_train with a timeout-aware runner (e.g., run_spawn_with_timeout(spawn_fn=_worker_single_task_train, nprocs=2, args=(2, port, self.data_dir, result_dict), timeout=60)) so rendezvous or worker hangs are terminated and the test asserts failure. Apply the same pattern to the other mp.spawn sites noted (the ranges referenced in the comment) so every training test uses the common 60s timeout helper.
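The watchdog pattern can be sketched without torch; `subprocess` stands in for `mp.spawn` here so the example stays self-contained, and `run_with_timeout` is a hypothetical helper name. The 60 s bound matches the testing guideline quoted above:

```python
import subprocess
import sys
import time


def run_with_timeout(cmd, timeout_s=60.0):
    """Launch the work in a separate process, wait at most timeout_s, and
    kill it if it is still alive. Returns (finished, returncode_or_None).
    """
    proc = subprocess.Popen(cmd)
    try:
        proc.wait(timeout=timeout_s)
        return True, proc.returncode
    except subprocess.TimeoutExpired:
        proc.kill()   # hard stop: a hung rendezvous never returns on its own
        proc.wait()   # reap the killed child before failing the test
        return False, None


# A worker that exceeds its budget gets killed instead of hanging the suite.
start = time.monotonic()
finished, rc = run_with_timeout(
    [sys.executable, "-c", "import time; time.sleep(30)"], timeout_s=1.0
)
elapsed = time.monotonic() - start
```

In the real tests the equivalent shape is a `multiprocessing.Process` whose target runs `mp.spawn(...)`, with `p.join(60)` followed by `p.terminate()` and a test failure if `p.is_alive()`, as the prompt describes.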
33-47: ⚠️ Potential issue | 🔴 Critical

`DEVICE=cpu` is set too late for spawned workers.

With `mp.spawn`, each child imports this module before executing `_worker_*`, so these module-level `deepmd.pt_expt` imports can resolve `deepmd.pt_expt.utils.env.DEVICE` before the worker sets the environment. That makes the test behavior depend on the parent process environment instead of staying CPU-only.
Verify each finding against the current code and only fix it if needed. In `@source/tests/pt_expt/test_training_ddp.py` around lines 33 - 47, The test sets DEVICE too late so module-level imports (e.g., importing get_trainer, get_finetune_rules, preprocess_shared_params, normalize, update_deepmd_input) cause deepmd.pt_expt to read env.DEVICE before spawned workers set it; to fix, ensure DEVICE is set to "cpu" before those imports by either moving os.environ["DEVICE"]="cpu" to the top of the module (before any deepmd.pt_expt imports) or by deferring these imports into the worker entrypoints (the worker functions named like _worker_*) so the environment is configured first.
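The import ordering the prompt asks for amounts to a two-line pattern; the `deepmd.pt_expt` import below is shown commented out since it is repo-specific:

```python
# Configure the environment BEFORE any import that reads it at import
# time; spawned children re-import this module, so they inherit the
# same ordering.
import os

os.environ["DEVICE"] = "cpu"  # must precede the framework imports

# Only after this point import modules that resolve DEVICE on import,
# e.g. (repo-specific, for illustration):
# from deepmd.pt_expt.entrypoints.main import get_trainer
```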
62-66: ⚠️ Potential issue | 🟠 Major
The free-port probe is still racy.
`_find_free_port()` closes the socket before the workers call `init_process_group()`, so another process can claim the port in between and make this suite flaky under parallel CI.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@source/tests/pt_expt/test_training_ddp.py` around lines 62 - 66, The current _find_free_port() closes the probe socket before workers call init_process_group(), which allows another process to grab the port; change _find_free_port to bind and listen but not close the socket (e.g., set SO_REUSEADDR, bind(("127.0.0.1", 0)), listen()) and return the open socket and its port number so the test can hold the socket open until after init_process_group() completes; update callers to keep the returned socket alive (and only close it after init_process_group() on all workers) to eliminate the race.

deepmd/pt_expt/entrypoints/main.py (2)
290-304: ⚠️ Potential issue | 🟠 Major
Always destroy the process group in a `finally`.
If `get_trainer()` or `trainer.run()` raises, teardown is skipped and the process group stays initialized in a bad state.
Minimal fix:
- if os.environ.get("LOCAL_RANK") is not None:
-     dist.init_process_group(backend="cuda:nccl,cpu:gloo")
-
- trainer = get_trainer(
-     config,
-     init_model,
-     restart,
-     finetune_model=finetune,
-     finetune_links=finetune_links,
-     shared_links=shared_links,
- )
- trainer.run()
-
- if dist.is_available() and dist.is_initialized():
-     dist.destroy_process_group()
+ try:
+     if os.environ.get("LOCAL_RANK") is not None:
+         dist.init_process_group(backend="cuda:nccl,cpu:gloo")
+
+     trainer = get_trainer(
+         config,
+         init_model,
+         restart,
+         finetune_model=finetune,
+         finetune_links=finetune_links,
+         shared_links=shared_links,
+     )
+     trainer.run()
+ finally:
+     if dist.is_available() and dist.is_initialized():
+         dist.destroy_process_group()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/pt_expt/entrypoints/main.py` around lines 290 - 304, Wrap the code that calls dist.init_process_group, get_trainer, and trainer.run in a try/finally so the process group is always torn down; specifically, after calling dist.init_process_group (when os.environ.get("LOCAL_RANK") is not None) call get_trainer(...) and trainer.run() inside a try block and in the finally call dist.destroy_process_group() only if dist.is_available() and dist.is_initialized(); update the block around dist.init_process_group, get_trainer, and trainer.run to ensure cleanup even if get_trainer() or trainer.run() raises.
87-95: ⚠️ Potential issue | 🟠 Major
Create parent stat directories in the single-task path too.
This branch still uses plain `mkdir()`, so nested targets like `./stat_files/model_1` work in multi-task mode and fail here.
Minimal fix:
- Path(stat_file_path).mkdir()
+ Path(stat_file_path).mkdir(parents=True, exist_ok=True)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/pt_expt/entrypoints/main.py` around lines 87 - 95, The code fails for nested stat paths because it calls Path(...).mkdir() without parents=True; update the branch that handles stat_file_path: when stat_file_path endswith h5/hdf5, ensure its parent directory exists (Path(stat_file_path).parent.mkdir(parents=True, exist_ok=True)) before creating the h5 file, and when treating stat_file_path as a directory call Path(stat_file_path).mkdir(parents=True, exist_ok=True); keep references to stat_file_path and DPPath in the same block so the rest of the logic is unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@deepmd/pt_expt/train/training.py`:
- Around line 531-540: The loop over self.model_keys assumes validation_data is
a mapping and indexes validation_data[model_key] unconditionally; when the
caller passes validation_data=None this raises. Fix by checking whether
validation_data is None before accessing validation_data[model_key]: compute
data_requirement via self.loss[model_key] and
get_additional_data_requirement(self.model[model_key]), call
training_data[model_key].add_data_requirements(...) as before, and only if
validation_data is not None call
validation_data[model_key].add_data_requirements(...) and set
self.validation_data[model_key] = validation_data[model_key]; otherwise set
self.validation_data[model_key] = None (or skip assignment) so the code handles
a global None validation_data safely.
---
Duplicate comments:
In `@deepmd/pt_expt/entrypoints/main.py`:
- Around line 290-304: Wrap the code that calls dist.init_process_group,
get_trainer, and trainer.run in a try/finally so the process group is always
torn down; specifically, after calling dist.init_process_group (when
os.environ.get("LOCAL_RANK") is not None) call get_trainer(...) and
trainer.run() inside a try block and in the finally call
dist.destroy_process_group() only if dist.is_available() and
dist.is_initialized(); update the block around dist.init_process_group,
get_trainer, and trainer.run to ensure cleanup even if get_trainer() or
trainer.run() raises.
- Around line 87-95: The code fails for nested stat paths because it calls
Path(...).mkdir() without parents=True; update the branch that handles
stat_file_path: when stat_file_path endswith h5/hdf5, ensure its parent
directory exists (Path(stat_file_path).parent.mkdir(parents=True,
exist_ok=True)) before creating the h5 file, and when treating stat_file_path as
a directory call Path(stat_file_path).mkdir(parents=True, exist_ok=True); keep
references to stat_file_path and DPPath in the same block so the rest of the
logic is unchanged.
In `@deepmd/pt_expt/train/training.py`:
- Around line 901-903: Change _compile_model to access model modules via
self._unwrapped.model instead of self.wrapper.model: iterate over
self.model_keys and use self._unwrapped.model[task_key] when reading or
assigning models (replace occurrences of self.wrapper.model[...] in
_compile_model). This ensures compatibility after DistributedDataParallel
wrapping; search for _compile_model, self.wrapper, and self._unwrapped.model to
update all direct wrapper.model dereferences (also apply same change around
lines referenced for the second occurrence).
- Around line 638-642: Summary: Replace silent-truncating dict(zip(...)) usages
with strict length checks to prevent silent loss when self.model_keys and
self.model_prob drift. Fix: before calling self.wrapper.share_params(...) where
model_key_prob_map=dict(zip(self.model_keys, self.model_prob)) (and the other
similar call later), validate that len(self.model_keys) == len(self.model_prob)
and raise a clear ValueError if not; then construct model_key_prob_map using
dict(zip(...)) safely. Reference symbols: self.wrapper.share_params,
model_key_prob_map, self.model_keys, self.model_prob; apply the same validation
and construction at the other occurrence flagged in the review.
- Around line 1225-1235: The per-task logging loop is advancing training
iterators because it calls get_data(is_train=True, task_key=_key) for non-active
heads; change that so non-active heads do not consume training batches—either
skip the get_data/_unwrapped call for _key != task_key or call
get_data(is_train=False, task_key=_key) (or another non-consuming/validation
accessor) so the training iterator is not advanced; update the loop around
self.model_keys (the block using self.optimizer.zero_grad(), self.get_data(...),
and self._unwrapped(...)) to only fetch training data for the active task_key
and use a non-consuming path for logging metrics.
In `@source/tests/pt_expt/test_multitask.py`:
- Around line 939-941: The test creates unused bindings (finetune_links_true and
n2) from get_finetune_rules and another unpack that Ruff flags; remove those
unused variables by changing the unpack to only capture what is used (e.g.,
assign just model_config_true = get_finetune_rules(...)[0] or unpack as
model_config_true, _ = get_finetune_rules(...)) and similarly drop or replace n2
with an underscore in the other unpack; update the lines that call
get_finetune_rules and any other tuple unpacking to avoid creating unused names.
In `@source/tests/pt_expt/test_training_ddp.py`:
- Around line 622-631: Wrap the mp.spawn call in test_ddp_single_task_trains
with a helper that enforces a 60s timeout and fails the test if any spawned
subprocess remains alive; specifically replace the direct mp.spawn(...)
invocation that launches _worker_single_task_train with a timeout-aware runner
(e.g., run_spawn_with_timeout(spawn_fn=_worker_single_task_train, nprocs=2,
args=(2, port, self.data_dir, result_dict), timeout=60)) so rendezvous or worker
hangs are terminated and the test asserts failure. Apply the same pattern to the
other mp.spawn sites noted (the ranges referenced in the comment) so every
training test uses the common 60s timeout helper.
- Around line 33-47: The test sets DEVICE too late so module-level imports
(e.g., importing get_trainer, get_finetune_rules, preprocess_shared_params,
normalize, update_deepmd_input) cause deepmd.pt_expt to read env.DEVICE before
spawned workers set it; to fix, ensure DEVICE is set to "cpu" before those
imports by either moving os.environ["DEVICE"]="cpu" to the top of the module
(before any deepmd.pt_expt imports) or by deferring these imports into the
worker entrypoints (the worker functions named like _worker_*) so the
environment is configured first.
- Around line 62-66: The current _find_free_port() closes the probe socket
before workers call init_process_group(), which allows another process to grab
the port; change _find_free_port to bind and listen but not close the socket
(e.g., set SO_REUSEADDR, bind(("127.0.0.1", 0)), listen()) and return the open
socket and its port number so the test can hold the socket open until after
init_process_group() completes; update callers to keep the returned socket alive
(and only close it after init_process_group() on all workers) to eliminate the
race.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 4def67e7-d481-4a14-a469-e833eb2857b5
📒 Files selected for processing (4)
- deepmd/pt_expt/entrypoints/main.py
- deepmd/pt_expt/train/training.py
- source/tests/pt_expt/test_multitask.py
- source/tests/pt_expt/test_training_ddp.py
for model_key in self.model_keys:
    data_requirement = self.loss[model_key].label_requirement
    data_requirement += get_additional_data_requirement(
        self.model[model_key]
    )
    training_data[model_key].add_data_requirements(data_requirement)
    if validation_data[model_key] is not None:
        validation_data[model_key].add_data_requirements(data_requirement)
    self.training_data[model_key] = training_data[model_key]
    self.validation_data[model_key] = validation_data[model_key]
Handle validation_data=None in the multi-task setup.
Trainer.__init__ accepts validation_data: ... | None, but this branch indexes validation_data[model_key] unconditionally. A multi-task caller that disables validation will raise here before training starts.
Minimal fix
else:
self.training_data = {}
self.validation_data = {}
self.valid_numb_batch = {}
for model_key in self.model_keys:
data_requirement = self.loss[model_key].label_requirement
data_requirement += get_additional_data_requirement(
self.model[model_key]
)
training_data[model_key].add_data_requirements(data_requirement)
- if validation_data[model_key] is not None:
- validation_data[model_key].add_data_requirements(data_requirement)
+ branch_validation = (
+ None if validation_data is None else validation_data.get(model_key)
+ )
+ if branch_validation is not None:
+ branch_validation.add_data_requirements(data_requirement)
self.training_data[model_key] = training_data[model_key]
- self.validation_data[model_key] = validation_data[model_key]
+ self.validation_data[model_key] = branch_validation🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@deepmd/pt_expt/train/training.py` around lines 531 - 540, The loop over
self.model_keys assumes validation_data is a mapping and indexes
validation_data[model_key] unconditionally; when the caller passes
validation_data=None this raises. Fix by checking whether validation_data is
None before accessing validation_data[model_key]: compute data_requirement via
self.loss[model_key] and get_additional_data_requirement(self.model[model_key]),
call training_data[model_key].add_data_requirements(...) as before, and only if
validation_data is not None call
validation_data[model_key].add_data_requirements(...) and set
self.validation_data[model_key] = validation_data[model_key]; otherwise set
self.validation_data[model_key] = None (or skip assignment) so the code handles
a global None validation_data safely.
In multi-task mode, validation_data is always a dict populated with all model keys (see main.py:98-136 — each key is set to either a DeepmdDataSystem or None). The [model_key] indexing is safe.
🧩 Analysis chain
Let me verify the claim by inspecting main.py.
🏁 Script executed:
#!/bin/bash
# Check the relevant section of main.py where validation_data is constructed
rg -n "validation_data" deepmd/pt_expt/entrypoints/main.py -A 3 -B 3
Length of output: 2093
@wanghan-iapcm, you're right. The main.py code confirms this clearly:
- Line 99: `validation_data = {}` is initialized as a dict in the multi-task branch.
- Lines 128 and 136: every `model_key` in `model_params["model_dict"]` is unconditionally assigned — either a `DeepmdDataSystem` or `None`.
So validation_data[model_key] is always safe when Trainer.__init__ is reached in multi-task mode. My review comment was incorrect — apologies for the noise!
✏️ Learnings added
Learnt from: wanghan-iapcm
Repo: deepmodeling/deepmd-kit PR: 5397
File: deepmd/pt_expt/train/training.py:510-519
Timestamp: 2026-04-17T12:29:01.518Z
Learning: In `deepmd/pt_expt/entrypoints/main.py`, in multi-task mode, `validation_data` is always constructed as a `dict` with every model key present before being passed to `Trainer.__init__`. Each key maps to either a `DeepmdDataSystem` or `None` (see lines ~99-136). Therefore, `validation_data[model_key]` indexing inside `Trainer.__init__`'s multi-task branch is always safe and does not need a `None`-guard on the outer dict.
🧠 Learnings used
Learnt from: wanghan-iapcm
Repo: deepmodeling/deepmd-kit PR: 5397
File: deepmd/pt_expt/train/training.py:1155-1168
Timestamp: 2026-04-17T12:28:09.479Z
Learning: In `deepmd/pt_expt/train/training.py`, during multi-task display steps inside `Trainer.run()`, rank 0 intentionally calls `get_data(is_train=True, task_key=_key)` and runs a forward pass through `_unwrapped` for non-selected tasks to populate `train_results`. This advances the training iterators for those tasks and is a deliberate design choice matching the PT backend's multi-task logging behavior: per-task loss metrics are displayed at every `disp_freq` step for all tasks, not just the one trained on in that step.
Cover the _compile_model DDP unwrap fix with single-task and multi-task tests that enable_compile=True under 2-rank gloo DDP.
Replace aot_eager+padding+manual recompile with symbolic make_fx + torch.compile(backend="inductor", dynamic=True). The compiled graph natively handles varying nframes/nloc/nall so the per-batch padding and runtime _recompile pass can be removed. Use a trace-time nframes of 7 (prime) and reshape with -1 in dpmodel (general_fitting, env_mat) to prevent PyTorch's symbolic tracer from unifying the batch dim with numb_fparam / numb_aparam / ntypes / dim_case_embd. Add TestCompiledVaryingNframesWithParams covering collisions with fparam/aparam, and TestCompileCaseEmbdVaryingNframes covering dim_case_embd > 0 with runtime nframes matching the embed dim.
…modeling#5393 Add silut/custom_silu support to _torch_activation using native torch ops (sigmoid, tanh, where) so the custom silu stays traceable by make_fx / torch.export. Cross-backend consistency tests cover multiple thresholds across the silu/tanh branches, and a pt_expt unit file exercises default/custom threshold, gradient flow, make_fx, and torch.export. Also port DescrptBlockRepformers accessor tests (get_rcut_smth, get_env_protection). The underlying accessor methods already exist on this branch; these tests guard against regressions.
Extend the compiled-vs-uncompiled assertions in TestCompiledConsistency (single-task) and _check_compile_correctness (multi-task) to also cover ``atom_energy`` and the reduced ``virial``. Atomic virial is intentionally not exercised because training never sets ``do_atomic_virial=True``.
…deling#5393 Adds the remaining tests from PR deepmodeling#5393 that were not yet on this branch: ``test_training_loop_compiled_silu`` (silu activation under torch.compile) and ``TestCompiledVaryingNatoms`` (compiled training across systems with different atom counts). Also drops a stray unused ``threshold`` variable in ``test_silut_below_threshold_is_silu`` to match the upstream PR.
Replace the two finite-loss smoke tests with a single test that builds both trainers, syncs weights, and per-step asserts identical predictions, loss, and per-parameter gradients (second-order through the force loss). Also add a silu full-model consistency test and write virial.npy in the small synthetic system so the virial passthrough is exercised on every step. Factor the prediction/grad comparison loops into shared helpers.
Parametrize TestCompiledVaryingNatoms over se_e2_a, DPA2 and DPA3 with strict atol=rtol=1e-10 on float64 (machine epsilon). DPA1 (se_atten) is intentionally omitted: its compiled path is intermittently incorrect (~20% of compiles produce grad diffs up to 0.67 at the first embedding layer), and including it would have required masking the bug with a loose tolerance.
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
source/tests/pt_expt/test_training.py (1)
291-343: ⚠️ Potential issue | 🟡 Minor
Add a timeout to the new training-heavy tests.
These cases run repeated compile/train loops but don't declare a timeout, so CI will sit on them indefinitely if one starts hanging.
As per coding guidelines, `**/tests/**/*training*.py`: Set training test timeouts to 60 seconds maximum for validation purposes, as real training takes hours or days.

Also applies to: 356-491, 789-1171
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@source/tests/pt_expt/test_training.py` around lines 291 - 343, The new long-running test TestCompiledDynamicShapes.test_compiled_handles_varying_nall lacks a CI timeout; add a 60-second timeout by decorating the test method with `@pytest.mark.timeout`(60) and importing pytest at top of the file (or use the project’s canonical test-timeout decorator if different), so the test will fail fast if it hangs; apply the same 60s timeout change to the other training-heavy tests referenced (the ranges 356-491 and 789-1171) by decorating their test methods (or their TestCase classes) similarly.
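The decoration the prompt suggests is a one-liner per test; the sketch below uses an illustrative test name, and the mark only has an effect when the `pytest-timeout` plugin is installed in the CI environment:

```python
import pytest


@pytest.mark.timeout(60)  # enforced by the pytest-timeout plugin at run time
def test_compiled_training_case():
    """Placeholder body; the real test runs the compile/train loop."""
    ...
```

The same decorator can also be applied at class level to cover every method of a training-heavy `TestCase`.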
♻️ Duplicate comments (9)
deepmd/pt_expt/train/training.py (4)
812-818: ⚠️ Potential issue | 🟠 Major
Keep the same `data_stat_protect` when reapplying sharing after load.
The initial `share_params()` call uses the validated branch value, but the resume path drops it. Any non-default configuration will restore one state and then immediately re-share stats with different protection semantics.
Minimal fix:
  if shared_links is not None:
      # Re-apply sharing after loading checkpoint
      self._unwrapped.share_params(
          shared_links,
          resume=True,
-         model_key_prob_map=dict(zip(self.model_keys, self.model_prob)),
+         model_key_prob_map=dict(
+             zip(self.model_keys, self.model_prob, strict=True)
+         ),
+         data_stat_protect=_data_stat_protect[0],
      )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/pt_expt/train/training.py` around lines 812 - 818, When reapplying sharing in the resume path, include the same validated data_stat_protect value so protection semantics aren’t lost; update the call to self._unwrapped.share_params(...) inside the resume branch to pass the data_stat_protect argument (the same variable used in the original share_params call) along with shared_links, resume=True, and the model_key_prob_map constructed from self.model_keys and self.model_prob.
499-508: ⚠️ Potential issue | 🟠 Major
Handle `validation_data=None` in the multi-task setup.
`__init__()` accepts `validation_data: ... | None`, but Line 505 indexes it unconditionally. A multi-task caller that disables validation will fail before training starts.
Minimal fix:
  for model_key in self.model_keys:
      data_requirement = self.loss[model_key].label_requirement
      data_requirement += get_additional_data_requirement(
          self.model[model_key]
      )
      training_data[model_key].add_data_requirements(data_requirement)
-     if validation_data[model_key] is not None:
-         validation_data[model_key].add_data_requirements(data_requirement)
+     branch_validation = (
+         None if validation_data is None else validation_data.get(model_key)
+     )
+     if branch_validation is not None:
+         branch_validation.add_data_requirements(data_requirement)
      self.training_data[model_key] = training_data[model_key]
-     self.validation_data[model_key] = validation_data[model_key]
+     self.validation_data[model_key] = branch_validation
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/pt_expt/train/training.py` around lines 499 - 508, The code unconditionally indexes validation_data[model_key] which fails if the caller passed validation_data=None; change the loop to guard that access—either replace uses of validation_data[model_key] with a safe lookup (e.g., if validation_data is not None and validation_data.get(model_key) is not None then call add_data_requirements and assign self.validation_data[model_key] = validation_data[model_key], otherwise set self.validation_data[model_key] = None), or pre-normalize validation_data to a dict of Nones for all self.model_keys before the loop so the existing logic (training_data, add_data_requirements, and assigning self.validation_data) works without raising when validation_data was None.
1139-1152: ⚠️ Potential issue | 🟠 Major
Don't consume training batches just to populate log output.
Line 1143 pulls extra training batches for the inactive heads on rank 0. That changes which samples those tasks actually train on, and in distributed runs only rank 0 advances those iterators further.
Please log from validation data or from a cached/non-advancing sample instead.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/pt_expt/train/training.py` around lines 1139 - 1152, The loop over self.model_keys currently calls self.get_data(is_train=True, task_key=_key) and self._unwrapped for inactive heads, which consumes training iterators and advances only rank 0; instead, log using non-advancing data: replace those get_data calls with a non-advancing source (e.g., use validation data via get_data(is_train=False, task_key=_key) or a cached/sample snapshot you capture earlier), and call self._unwrapped with that non-advancing label/input; keep the rest of the collection logic (train_results[_key] = ...) the same so logging no longer mutates training iterators for self.model_keys and task_key.
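One non-consuming alternative along the lines the review suggests (the PR author notes the current iterator-advancing behavior is intentional) is to cache the last batch each task actually trained on and reuse it at display steps. The class below is a hypothetical sketch, not repo code:

```python
class LastBatchCache:
    """Remember the most recent training batch per task for logging.

    At display steps, metrics for non-active heads are computed from the
    cached batch instead of pulling a fresh one, so training iterators
    are never advanced just to produce log output.
    """

    def __init__(self):
        self._last = {}

    def update(self, task_key, batch):
        # Called once per training step for the active task only.
        self._last[task_key] = batch

    def get(self, task_key):
        # None means the task has not trained yet; skip its log entry.
        return self._last.get(task_key)
```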
606-610: ⚠️ Potential issue | 🟡 Minor
Add `strict=True` to both `zip()` calls.
Both sites silently truncate if `self.model_keys` and `self.model_prob` ever drift, and Ruff will flag them either way.
As per coding guidelines, `**/*.py`: Install linter and run `ruff check .` before committing changes or the CI will fail.
Minimal fix (applies at both sites):
- model_key_prob_map=dict(zip(self.model_keys, self.model_prob)),
+ model_key_prob_map=dict(
+     zip(self.model_keys, self.model_prob, strict=True)
+ ),

Also applies to: 814-817
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/pt_expt/train/training.py` around lines 606 - 610, The zip calls that build the model_key_prob_map (pairing self.model_keys and self.model_prob) must be made strict to avoid silent truncation; update the dict(zip(...)) in the self.wrapper.share_params call (model_key_prob_map argument) to use zip(self.model_keys, self.model_prob, strict=True), and do the same for the other zip later in the file that also pairs self.model_keys with self.model_prob so both mappings enforce equal lengths.

source/tests/pt_expt/test_multitask.py (2)
1951-1953: ⚠️ Potential issue | 🟡 Minor
Rename the unused `n2` loop variable.
Ruff will flag `n2` as unused here.
As per coding guidelines, `**/*.py`: Install linter and run `ruff check .` before committing changes or the CI will fail.
Minimal fix:
- for (n1, p1), (n2, p2) in zip(
+ for (n1, p1), (_, p2) in zip(
      mt_desc.named_parameters(), mt_desc_2.named_parameters(), strict=True
  ):
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@source/tests/pt_expt/test_multitask.py` around lines 1951 - 1953, The loop declares unused variable n2 causing linter errors; rename n2 to a throwaway name (e.g., _n2 or _) in the for loop that iterates over mt_desc.named_parameters() and mt_desc_2.named_parameters() so only p2 is treated as used (i.e., for (n1, p1), (_n2, p2) in zip(..., strict=True)). Ensure the change touches the loop in test_multitask.py where mt_desc and mt_desc_2 named_parameters() are zipped so Ruff no longer flags an unused variable.
939-941: ⚠️ Potential issue | 🟡 Minor
Drop the unused `finetune_links_true` binding.
Ruff will flag this local as unused.
As per coding guidelines, `**/*.py`: Install linter and run `ruff check .` before committing changes or the CI will fail.
Minimal fix:
- model_config_true, finetune_links_true = get_finetune_rules(
+ model_config_true, _ = get_finetune_rules(
      ckpt_path, deepcopy(ft_config_true["model"]), change_model_params=True
  )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@source/tests/pt_expt/test_multitask.py` around lines 939 - 941, The call to get_finetune_rules returns two values but the second (finetune_links_true) is unused; remove the unused binding by either unpacking only the first return (e.g., assign the single result of get_finetune_rules to model_config_true) or use a throwaway name like _ for the second value so Ruff no longer flags finetune_links_true as unused; update the invocation of get_finetune_rules where model_config_true, finetune_links_true is currently used.

source/tests/pt_expt/test_training_ddp.py (3)
63-67: ⚠️ Potential issue | 🟠 Major
The free-port probe is still racy for DDP rendezvous.
This helper releases the socket before the workers call `init_process_group()`, so another parallel test can claim the same port and make the suite flaky. Please switch to a rendezvous that stays allocated for the whole run, e.g. `file://...` or a pre-created `TCPStore`.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@source/tests/pt_expt/test_training_ddp.py` around lines 63 - 67, The helper _find_free_port() is racy because it closes the probe socket before workers call init_process_group(), allowing races on the chosen port; replace this pattern by using a rendezvous that stays allocated for the whole test run (for example switch the rendezvous URL from tcp://<port> to a file://... path or create a persistent torch.distributed.TCPStore ahead of spawning workers) and update the test harness that builds the init_process_group() args to use the file-based rendezvous or the pre-created TCPStore so the address/resource remains reserved for the duration of the DDP test.
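The port-holding approach described in the prompt can be sketched as below. The helper name `reserve_port` is illustrative; the key point is that the caller keeps the returned socket alive until every worker has finished rendezvous (a file- or `TCPStore`-based rendezvous avoids the port question entirely):

```python
import socket


def reserve_port():
    """Bind an OS-assigned port and hold the listening socket open.

    The caller must keep `sock` alive until all workers have completed
    init_process_group(); closing it earlier reintroduces the race.
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    sock.bind(("127.0.0.1", 0))
    sock.listen()
    return sock, sock.getsockname()[1]
```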
17-18: ⚠️ Potential issue | 🟠 Major
Set `DEVICE=cpu` before importing `deepmd.pt_expt`.
The worker-level override is too late here. `mp.spawn` starts a fresh interpreter that imports this module before entering `_worker_*`, so the module-scope `deepmd.pt_expt` imports can lock in the parent/default device first. On CUDA-capable runners that makes this “CPU-only” suite depend on the ambient device instead of the explicit test setting.
Suggested fix:
  import os
  import shutil
  import socket
  import tempfile
  import unittest
@@
+ os.environ["DEVICE"] = "cpu"
+
  from deepmd.pt_expt.entrypoints.main import (
      get_trainer,
  )

Also applies to: 34-48
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@source/tests/pt_expt/test_training_ddp.py` around lines 17 - 18, Set the environment variable DEVICE=cpu before importing deepmd.pt_expt in the module scope of the test file so spawned worker interpreters don't inherit the ambient CUDA device; modify the top of source/tests/pt_expt/test_training_ddp.py to export DEVICE=cpu (e.g., via os.environ["DEVICE"]="cpu") prior to any import that pulls in deepmd.pt_expt or before any use of mp.spawn/_worker_* so the child processes import with the correct CPU-only setting.
623-632: ⚠️ Potential issue | 🟠 Major
These spawned training tests still need a hard 60s timeout.
`mp.spawn(..., join=True)` can hang indefinitely on a broken rendezvous or stuck worker, and this pattern repeats through the rest of the module. Please wrap the spawn call in a bounded subprocess/helper so each training case fails fast after 60 seconds.
As per coding guidelines: "Set training test timeouts to 60 seconds maximum for validation purposes, as real training takes hours or days".
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@source/tests/pt_expt/test_training_ddp.py` around lines 623 - 632, The test test_ddp_single_task_trains uses mp.spawn to run _worker_single_task_train and can hang indefinitely; wrap the mp.spawn call in a bounded timeout helper (e.g., run it in a short-lived subprocess or use a helper that starts the spawn in a separate Process and joins with a 60-second timeout) so that if the spawn does not complete within 60 seconds the test forcefully terminates the worker(s) and fails; ensure the same pattern (60s hard timeout) is applied wherever mp.spawn is called in this module.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@deepmd/pt_expt/train/training.py`:
- Around line 603-605: Replace the assert check on _data_stat_protect with an
explicit exception to avoid being skipped under python -O: validate that
np.allclose(_data_stat_protect, _data_stat_protect[0]) and if it fails raise a
ValueError (or ConfigError) with a clear message that includes the mismatched
values and mentions the model key 'data_stat_protect' and multitask branches;
update the code around the existing assert (the validation that currently reads
"Model key 'data_stat_protect' must be the same in each branch when multitask!")
to perform this explicit check and raise the informative exception.
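A minimal sketch of that explicit check follows; the function and argument names are assumed for illustration, not the trainer's real ones:

```python
import numpy as np


def check_data_stat_protect(values, model_keys):
    """Validate data_stat_protect across multitask branches.

    Raises ValueError instead of using assert, so the check survives
    running under `python -O` and reports the mismatched values.
    """
    arr = np.asarray(values, dtype=float)
    if not np.allclose(arr, arr[0]):
        raise ValueError(
            "Model key 'data_stat_protect' must be the same in each branch "
            f"when multitask! Got {dict(zip(model_keys, arr.tolist()))}."
        )
    return float(arr[0])
```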
In `@source/tests/consistent/test_activation.py`:
- Around line 110-117: The test functions (e.g.,
test_pt_expt_consistent_with_ref) call .detach().numpy() on tensors created on
PT_EXPT_DEVICE which may be a GPU device; move the tensor to CPU before
converting to NumPy by inserting .cpu() before .numpy() (apply the same change
to the second occurrence around lines 148-159 and any similar tests using
PT_EXPT_DEVICE) so calls become .detach().cpu().numpy() to avoid RuntimeError on
non-CPU devices.
---
Outside diff comments:
In `@source/tests/pt_expt/test_training.py`:
- Around line 291-343: The new long-running test
TestCompiledDynamicShapes.test_compiled_handles_varying_nall lacks a CI timeout;
add a 60-second timeout by decorating the test method with
`@pytest.mark.timeout`(60) and importing pytest at top of the file (or use the
project’s canonical test-timeout decorator if different), so the test will fail
fast if it hangs; apply the same 60s timeout change to the other training-heavy
tests referenced (the ranges 356-491 and 789-1171) by decorating their test
methods (or their TestCase classes) similarly.
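The decoration the comment describes would look like this, assuming the `pytest-timeout` plugin supplies the marker (the training-heavy test body is elided):

```python
import pytest


@pytest.mark.timeout(60)  # hard 60 s cap; enforced by the pytest-timeout plugin
def test_compiled_handles_varying_nall():
    ...  # training-heavy body elided
```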
---
Duplicate comments:
In `@deepmd/pt_expt/train/training.py`:
- Around line 812-818: When reapplying sharing in the resume path, include the
same validated data_stat_protect value so protection semantics aren’t lost;
update the call to self._unwrapped.share_params(...) inside the resume branch to
pass the data_stat_protect argument (the same variable used in the original
share_params call) along with shared_links, resume=True, and the
model_key_prob_map constructed from self.model_keys and self.model_prob.
- Around line 499-508: The code unconditionally indexes
validation_data[model_key] which fails if the caller passed
validation_data=None; change the loop to guard that access—either replace uses
of validation_data[model_key] with a safe lookup (e.g., if validation_data is
not None and validation_data.get(model_key) is not None then call
add_data_requirements and assign self.validation_data[model_key] =
validation_data[model_key], otherwise set self.validation_data[model_key] =
None), or pre-normalize validation_data to a dict of Nones for all
self.model_keys before the loop so the existing logic (training_data,
add_data_requirements, and assigning self.validation_data) works without raising
when validation_data was None.
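The pre-normalization variant of that fix reduces to a small helper (a sketch; the name is hypothetical):

```python
def normalize_validation_data(validation_data, model_keys):
    """Return a dict with an entry (dataset or None) for every task key.

    After this, the training loop can index validation_data[model_key]
    unconditionally, even when the caller passed validation_data=None.
    """
    if validation_data is None:
        return dict.fromkeys(model_keys)
    return {key: validation_data.get(key) for key in model_keys}
```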
- Around line 1139-1152: The loop over self.model_keys currently calls
self.get_data(is_train=True, task_key=_key) and self._unwrapped for inactive
heads, which consumes training iterators and advances only rank 0; instead, log
using non-advancing data: replace those get_data calls with a non-advancing
source (e.g., use validation data via get_data(is_train=False, task_key=_key) or
a cached/sample snapshot you capture earlier), and call self._unwrapped with
that non-advancing label/input; keep the rest of the collection logic
(train_results[_key] = ...) the same so logging no longer mutates training
iterators for self.model_keys and task_key.
- Around line 606-610: The zip calls that build the model_key_prob_map (pairing
self.model_keys and self.model_prob) must be made strict to avoid silent
truncation; update the dict(zip(...)) in the self.wrapper.share_params call
(model_key_prob_map argument) to use zip(self.model_keys, self.model_prob,
strict=True), and do the same for the other zip later in the file that also
pairs self.model_keys with self.model_prob so both mappings enforce equal
lengths.
In `@source/tests/pt_expt/test_multitask.py`:
- Around line 1951-1953: The loop declares unused variable n2 causing linter
errors; rename n2 to a throwaway name (e.g., _n2 or _) in the for loop that
iterates over mt_desc.named_parameters() and mt_desc_2.named_parameters() so
only p2 is treated as used (i.e., for (n1, p1), (_n2, p2) in zip(...,
strict=True)). Ensure the change touches the loop in test_multitask.py where
mt_desc and mt_desc_2 named_parameters() are zipped so Ruff no longer flags an
unused variable.
- Around line 939-941: The call to get_finetune_rules returns two values but the
second (finetune_links_true) is unused; remove the unused binding by either
unpacking only the first return (e.g., assign the single result of
get_finetune_rules to model_config_true) or use a throwaway name like _ for the
second value so Ruff no longer flags finetune_links_true as unused; update the
invocation of get_finetune_rules where model_config_true, finetune_links_true is
currently used.
In `@source/tests/pt_expt/test_training_ddp.py`:
- Around line 63-67: The helper _find_free_port() is racy because it closes the
probe socket before workers call init_process_group(), allowing races on the
chosen port; replace this pattern by using a rendezvous that stays allocated for
the whole test run (for example switch the rendezvous URL from tcp://<port> to a
file://... path or create a persistent torch.distributed.TCPStore ahead of
spawning workers) and update the test harness that builds the
init_process_group() args to use the file-based rendezvous or the pre-created
TCPStore so the address/resource remains reserved for the duration of the DDP
test.
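One way to realize the file-based rendezvous is to build the `init_process_group` kwargs around a shared file that lives for the whole run (a sketch; the helper name is an assumption, and the kwargs would be passed to `torch.distributed.init_process_group` inside each worker):

```python
import pathlib
import tempfile


def file_rendezvous_kwargs(rank, world_size, rdzv_dir):
    """Build init_process_group kwargs backed by a file:// store.

    The rendezvous file stays allocated for the duration of the test,
    so there is no port to race on.
    """
    store_file = pathlib.Path(rdzv_dir) / "ddp_rendezvous"
    return {
        "backend": "gloo",
        "init_method": f"file://{store_file}",
        "rank": rank,
        "world_size": world_size,
    }
```

The test harness would create a `tempfile.TemporaryDirectory()` before spawning workers and pass its path to every worker.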
- Around line 17-18: Set the environment variable DEVICE=cpu before importing
deepmd.pt_expt in the module scope of the test file so spawned worker
interpreters don't inherit the ambient CUDA device; modify the top of
source/tests/pt_expt/test_training_ddp.py to export DEVICE=cpu (e.g., via
os.environ["DEVICE"]="cpu") prior to any import that pulls in deepmd.pt_expt or
before any use of mp.spawn/_worker_* so the child processes import with the
correct CPU-only setting.
- Around line 623-632: The test test_ddp_single_task_trains uses mp.spawn to run
_worker_single_task_train and can hang indefinitely; wrap the mp.spawn call in a
bounded timeout helper (e.g., run it in a short-lived subprocess or use a helper
that starts the spawn in a separate Process and joins with a 60-second timeout)
so that if the spawn does not complete within 60 seconds the test forcefully
terminates the worker(s) and fails; ensure the same pattern (60s hard timeout)
is applied wherever mp.spawn is called in this module.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 083dd26d-6c3a-41e9-a2dc-22f6405f9ba9
📒 Files selected for processing (10)
- deepmd/dpmodel/fitting/general_fitting.py
- deepmd/dpmodel/utils/env_mat.py
- deepmd/pt_expt/train/training.py
- deepmd/pt_expt/utils/network.py
- source/tests/common/dpmodel/test_descriptor_dpa2.py
- source/tests/consistent/test_activation.py
- source/tests/pt_expt/test_multitask.py
- source/tests/pt_expt/test_training.py
- source/tests/pt_expt/test_training_ddp.py
- source/tests/pt_expt/utils/test_activation.py
🚧 Files skipped from review as they are similar to previous changes (1)
- deepmd/dpmodel/fitting/general_fitting.py
```python
@unittest.skipUnless(INSTALLED_PT_EXPT, "PyTorch Exportable is not installed")
def test_pt_expt_consistent_with_ref(self) -> None:
    if INSTALLED_PT_EXPT:
        x = torch.tensor(
            self.random_input, dtype=torch.float64, device=PT_EXPT_DEVICE
        )
        test = _torch_activation(x, self.activation).detach().numpy()
        np.testing.assert_allclose(self.ref, test, atol=1e-10)
```
🧩 Analysis chain
🏁 Scripts executed:

```shell
# First, let's locate the file and read the relevant sections
wc -l source/tests/consistent/test_activation.py

# Read the file around lines 110-117
sed -n '105,120p' source/tests/consistent/test_activation.py

# Read the file around lines 152-159
sed -n '147,165p' source/tests/consistent/test_activation.py

# Find where PT_EXPT_DEVICE is defined
rg "PT_EXPT_DEVICE" --context 3

# Find the definition of DEVICE in deepmd.pt_expt.utils.env
find . -path "*/deepmd/pt_expt/utils/env*" -type f

# Read the env module to understand what DEVICE is
cat deepmd/pt_expt/utils/env.py | head -50

# Search for the actual definition of DEVICE
rg "^DEVICE" deepmd/pt_expt/utils/env.py -A 3

# Read more of the env.py file to find DEVICE definition
cat deepmd/pt_expt/utils/env.py | tail -50

# Search more broadly for DEVICE
rg "DEVICE\s*=" deepmd/pt_expt/utils/env.py

# Get more context around the DEVICE definition
rg "DEVICE" deepmd/pt_expt/utils/env.py -B 5 -A 2

# Let's verify both test method implementations fully
sed -n '108,125p' source/tests/consistent/test_activation.py

# Check the second occurrence around line 152-159
sed -n '148,165p' source/tests/consistent/test_activation.py

# Look at how other files handle detach().numpy() with PT_EXPT_DEVICE
rg "detach\(\)\.cpu\(\)\.numpy\(\)" source/tests/consistent/fitting/ -B 2

# Search for the detach().cpu().numpy() pattern across the whole codebase
rg "detach\(\)\.cpu\(\)\.numpy\(\)" -B 2

# Look at test_polar.py to see the pattern mentioned in the earlier search
sed -n '75,95p' source/tests/consistent/fitting/test_polar.py
```
Add .cpu() before calling numpy() on tensors created on PT_EXPT_DEVICE.
PT_EXPT_DEVICE may resolve to CUDA/MPS if the system has GPUs available. In such cases, calling .detach().numpy() directly raises a RuntimeError. The tensor must be moved to CPU first.
Suggested fix:

```diff
- test = _torch_activation(x, self.activation).detach().numpy()
+ test = _torch_activation(x, self.activation).detach().cpu().numpy()
```

This applies to both occurrences (lines 110-117 and 148-159).
Good catch on the .cpu() call. However, test_activation.py is in source/tests/consistent/ which is outside the scope of this PR (this PR only touches pt_expt code). The consistent tests have their own device handling. Will address if it becomes an issue.
@wanghan-iapcm, understood — thanks for the context. Since the consistent tests manage their own device abstractions, it makes sense to defer this. Feel free to revisit if a GPU CI failure surfaces for this path.
🧠 Learnings used
Learnt from: wanghan-iapcm
Repo: deepmodeling/deepmd-kit PR: 5370
File: source/lmp/tests/test_lammps_spin_nopbc_pt2.py:22-23
Timestamp: 2026-04-04T17:03:29.434Z
Learning: In deepmd-kit, all LAMMPS Python test files under `source/lmp/tests/` (e.g., `test_lammps_spin_pt.py`, `test_lammps_pt2.py`, `test_lammps_spin_nopbc_pt2.py`) intentionally use module-level fixed Path constants (`data.lmp`, `md.out`) managed by `setup_module`/`teardown_module`. This is the established repo-wide convention; do not flag it as a file-collision or parallel-run risk, as changing it would break consistency with the rest of the test suite.
Learnt from: wanghan-iapcm
Repo: deepmodeling/deepmd-kit PR: 5397
File: deepmd/pt_expt/train/training.py:1155-1168
Timestamp: 2026-04-17T12:28:09.479Z
Learning: In `deepmd/pt_expt/train/training.py`, during multi-task display steps inside `Trainer.run()`, rank 0 intentionally calls `get_data(is_train=True, task_key=_key)` and runs a forward pass through `_unwrapped` for non-selected tasks to populate `train_results`. This advances the training iterators for those tasks and is a deliberate design choice matching the PT backend's multi-task logging behavior: per-task loss metrics are displayed at every `disp_freq` step for all tasks, not just the one trained on in that step.
Learnt from: wanghan-iapcm
Repo: deepmodeling/deepmd-kit PR: 5391
File: source/tests/pt_expt/infer/test_deep_eval.py:1667-1691
Timestamp: 2026-04-11T08:01:08.364Z
Learning: In deepmd-kit, `get_dp_atomic_model()` is a model-level API exposed on the full model hierarchy (e.g., `make_model` CM class, `FrozenModel`, `SpinModel`) and is not available on atomic models (e.g., `DPZBLLinearEnergyAtomicModel`, `DPAtomicModel`). When writing/validating tests for ZBL or other atomic-model behavior, do not call `get_dp_atomic_model()` on atomic model instances; instead, use `isinstance` checks (or equivalent type assertions) on the atomic model.
Pull request overview
This PR adds multi-task training support to the pt_expt backend, including shared-parameter/stat handling across tasks, distributed (DDP) execution, and improved torch.compile support for dynamic shapes and higher-order gradients.
Changes:
- Introduces multi-task model wiring (`ModelWrapper`, `Trainer`, config preprocessing) with optional parameter sharing and per-task case embeddings.
- Adds DDP support (rank-0 gating, stat broadcasting, `find_unused_parameters`) and switches the compile path to symbolic tracing + `torch.compile(dynamic=True)`.
- Extends/aligns stats handling and activation support (`silut`/`custom_silu`), plus substantial new test coverage for multi-task, compile correctness, and stat merging.
Reviewed changes
Copilot reviewed 40 out of 40 changed files in this pull request and generated 5 comments.
Show a summary per file
| File | Description |
|---|---|
| source/tests/pt_expt/utils/test_activation.py | Adds pt_expt-specific tests for silut activation correctness/grad/export/FX tracing. |
| source/tests/pt_expt/test_training.py | Expands compile correctness, dynamic-shape, grad-consistency, and varying-shape training tests. |
| source/tests/pt_expt/test_finetune.py | Updates finetune test paths to multi-task model wrapper structure (["Default"]). |
| source/tests/pt_expt/test_change_bias.py | Updates bias access for wrapped multi-task model (["Default"]). |
| source/tests/pt_expt/fitting/test_fitting_stat.py | Adds multi-task shared fitting-stat tests and probability-weighted stat expectations. |
| source/tests/pt_expt/descriptor/test_se_t_tebd.py | Adds share_params behavior tests for se_t_tebd. |
| source/tests/pt_expt/descriptor/test_se_t.py | Adds share_params tests for se_t. |
| source/tests/pt_expt/descriptor/test_se_r.py | Adds share_params tests for se_r. |
| source/tests/pt_expt/descriptor/test_se_atten_v2.py | Adds share_params tests for se_atten_v2. |
| source/tests/pt_expt/descriptor/test_hybrid.py | Adds recursive share_params tests for hybrid descriptors. |
| source/tests/pt_expt/descriptor/test_dpa3.py | Adds share_params tests for DPA3 descriptor sharing levels. |
| source/tests/pt_expt/descriptor/test_dpa1.py | Adds share_params tests for DPA1 descriptor sharing levels. |
| source/tests/pt_expt/descriptor/test_descrpt_stat_merge.py | Adds extensive unit tests for probability-weighted descriptor-stat merging. |
| source/tests/pt/test_fitting_stat.py | Aligns PT test nbatch values with actual frame counts (80). |
| source/tests/consistent/test_activation.py | Adds cross-backend consistency checks for pt_expt activation and silut variants. |
| source/tests/common/dpmodel/test_descriptor_dpa2.py | Adds accessor tests for new Repformers getters. |
| deepmd/pt_expt/utils/network.py | Adds silut/custom_silu activation implementation to pt_expt. |
| deepmd/pt_expt/utils/multi_task.py | Introduces shared-config preprocessing to build per-task configs and sharing links. |
| deepmd/pt_expt/utils/finetune.py | Extends finetune rules to cover multi-task → multi-task and single-task → multi-task flows. |
| deepmd/pt_expt/train/wrapper.py | Generalizes wrapper to handle multi-model/loss dicts and per-task forward/loss evaluation. |
| deepmd/pt_expt/train/training.py | Implements multi-task training loop, DDP handling, dynamic-shape compile, and detach-node stripping for 2nd-order grads. |
| deepmd/pt_expt/fitting/invar_fitting.py | Adds share_params with probability-weighted fparam/aparam stat merging and selective sharing exclusions. |
| deepmd/pt_expt/fitting/ener_fitting.py | Hooks EnergyFittingNet.share_params to the InvarFitting implementation. |
| deepmd/pt_expt/entrypoints/main.py | Adds multi-task preprocessing, DDP init/teardown, and multi-task --head support for freeze. |
| deepmd/pt_expt/descriptor/se_t_tebd.py | Adds descriptor-level share_params with stat merging via merge_env_stat. |
| deepmd/pt_expt/descriptor/se_t.py | Adds descriptor-level share_params with stat merging via merge_env_stat. |
| deepmd/pt_expt/descriptor/se_r.py | Adds descriptor-level share_params with stat merging via merge_env_stat. |
| deepmd/pt_expt/descriptor/se_e2_a.py | Adds descriptor-level share_params with stat merging via merge_env_stat. |
| deepmd/pt_expt/descriptor/se_atten_v2.py | Delegates share_params to DPA1 logic for consistent sharing behavior. |
| deepmd/pt_expt/descriptor/hybrid.py | Adds recursive share_params for hybrid descriptors. |
| deepmd/pt_expt/descriptor/dpa3.py | Adds DPA3 share_params with sharing levels and stat merging. |
| deepmd/pt_expt/descriptor/dpa2.py | Adds DPA2 share_params including three-body stat merge/sharing when enabled. |
| deepmd/pt_expt/descriptor/dpa1.py | Adds DPA1 share_params with sharing levels and stat merging. |
| deepmd/pt/model/task/fitting.py | Removes explicit fparam/aparam dim checks in PT path (reshape-only behavior). |
| deepmd/dpmodel/utils/env_mat_stat.py | Adds probability-weighted merge_env_stat helper used by pt_expt sharing. |
| deepmd/dpmodel/utils/env_mat.py | Adjusts reshaping logic to avoid symbolic-dim collisions during tracing/compile. |
| deepmd/dpmodel/fitting/general_fitting.py | Stores raw param stats and reshapes with -1 to avoid symbolic specialization; adds accessor. |
| deepmd/dpmodel/descriptor/repformers.py | Adds get_rcut_smth() and get_env_protection() accessors. |
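The probability-weighted stat merging that `merge_env_stat` performs can be sketched numerically. This is illustrative math only, not the helper's real signature: the merged mean is the probability-weighted mean, and the merged variance is the weighted within-task variance plus the between-task spread.

```python
import numpy as np


def merge_weighted_stats(means, stds, probs):
    """Probability-weighted merge of per-task stats (illustrative sketch).

    means, stds: (ntask, ndim) per-task statistics
    probs: (ntask,) task probabilities (normalized here for safety)
    """
    means = np.asarray(means, dtype=float)
    stds = np.asarray(stds, dtype=float)
    probs = np.asarray(probs, dtype=float)
    probs = probs / probs.sum()
    # weighted mean across tasks
    m = np.sum(probs[:, None] * means, axis=0)
    # law of total variance: within-task variance + between-task spread
    var = np.sum(probs[:, None] * (stds**2 + (means - m) ** 2), axis=0)
    return m, np.sqrt(var)
```

Merging identical stats is a no-op, while tasks with differing means contribute extra spread to the merged std.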
```python
import torch.distributed as dist


if os.environ.get("LOCAL_RANK") is not None:
    dist.init_process_group(backend="cuda:nccl,cpu:gloo")
```
"cuda:nccl,cpu:gloo" is valid PyTorch 2.0+ syntax for device-specific backend mapping. See PyTorch docs — it routes CUDA tensors through NCCL and CPU tensors through Gloo in a single process group.
```python
model_index = dp_random.choice(
    np.arange(self.num_model, dtype=np.int_),
    p=self.model_prob,
)
```
This mirrors the PT backend exactly (deepmd/pt/train/training.py:1079). dp_random uses a shared np.random.RandomState seeded identically across ranks, so all ranks produce the same task sequence.
```python
elif item_params.get("type", "") == "hybrid":
    for ii, hybrid_item in enumerate(item_params["list"]):
        if isinstance(hybrid_item, str):
            replace_one_item(
                model_params_item[item_key]["list"],
                item_key,
                hybrid_item,
                suffix=f"_hybrid_{ii}",
                index=ii,
            )
```
item_params is either a str (handled by line 84) or a dict (inline definition), enforced by config schema validation in deepmd.utils.argcheck.normalize() which runs before preprocess_shared_params.
| f"which is not consistent with {self.numb_fparam}." | ||
| ) | ||
| f"input fparam: cannot reshape {fparam.shape} " | ||
| f"into ({nf}, {self.numb_fparam})." |
This was valid at the time of review — the code previously used (-1, self.numb_fparam). Fixed in 80c714c which changed the reshape to (nf, self.numb_fparam), now matching the error message.
```diff
@@ -804,11 +799,6 @@ def _forward_common(
     assert aparam is not None, "aparam should not be None"
     assert self.aparam_avg is not None
     assert self.aparam_inv_std is not None
-    if aparam.shape[-1] != self.numb_aparam:
-        raise ValueError(
-            f"get an input aparam of dim {aparam.shape[-1]}, ",
-            f"which is not consistent with {self.numb_aparam}.",
-        )
     aparam = aparam.view([nf, -1, self.numb_aparam])
     nb, nloc, _ = aparam.shape
```
Good catch. The removed checks were buggy (non-f-string + tuple ValueError), but the user-friendly error is worth keeping. Fixed in 6ae50db — added try/except wrapping to match the dpmodel pattern.
Enable use_three_body=True on the DPA2 varying-natoms test so the compiled three-body neighbor path is also covered. three_body_rcut=3.0 matches the repformer rcut and is large enough to find neighbors in the 6-atom small system (~2.75Å nearest-neighbor distance).
…ckward decomp Revert (-1, nloc*nnei, ...) reshapes back to (nf, -1, ...) in env_mat.py and general_fitting.py. The -1-for-nf pattern breaks zero-atom systems: numpy cannot infer -1 when other dims multiply to zero (0/0), and torch.export shape assertions hit Mod(0,0). Using nf is safe because _TRACE_NFRAMES=7 already prevents symbolic-tracer specialisation during training compile. Add silu_backward decomposition table to make_fx in training.py so inductor can compile second-order gradients through silu without requiring a fused higher-order derivative kernel.
Replace assert with if/raise ValueError for user-facing config validation (data_stat_protect, finetune branch/head checks). Wrap train() in try/finally for destroy_process_group cleanup. Add parents=True, exist_ok=True to stat_file mkdir. Add strict=True to zip() calls. Fix minor test issues.
Match the dpmodel try/except pattern so shape mismatches produce a clear error instead of a raw RuntimeError from torch.view.
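The dpmodel-style try/except pattern referenced here can be sketched with numpy (the helper name and arguments are assumed for illustration):

```python
import numpy as np


def reshape_fparam(fparam, nf, numb_fparam):
    """Reshape fparam, converting a raw shape error into a clear message."""
    try:
        return np.reshape(fparam, (nf, numb_fparam))
    except ValueError as e:
        raise ValueError(
            f"input fparam: cannot reshape {fparam.shape} "
            f"into ({nf}, {numb_fparam})."
        ) from e
```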
Force-pushed 6ae50db to c2efbf1 (Compare)
Summary
- Stat merging across tasks (`merge_env_stat`), shared fitting via `share_params`, and `case_embd` per-task embedding
- DDP support with `find_unused_parameters`
- `torch.compile` for multi-task with `backend="inductor"`, `dynamic=True`, and symbolic tracing (`make_fx(tracing_mode="symbolic")`); includes `silu_backward` decomposition for second-order gradient compatibility
- `silut` activation variant and `DescrptBlockRepformers` accessors (ported from perf(pt_expt): use inductor+dynamic for torch.compile training #5393)

Known limitations
- No `num_epoch_dict`: only `numb_steps` + `model_prob`; epoch-based scheduling deferred
- `EnergyFittingNet` has `share_params`; other fitting types (DOS, dipole, polar, property) need the same override
- `share_fitting` + single-task finetune is incompatible (no `dim_case_embd` in pretrained)
- `use_dynamic_sel: true` cannot compile (symbolic tracer fails on data-dependent `int()` in `get_graph_index`)

Test plan
- Multi-task (`test_multitask.py`): training, freeze, finetune, compile, shared fitting, DPA1/DPA2/DPA3/SeA
- Training (`test_training.py`): compile correctness, dynamic shapes, silu compile
- DDP (`test_training_ddp.py`): single-task + multi-task with compile
- Descriptor stat merging (`test_descrpt_stat_merge.py`)
- Fitting stats (`test_fitting_stat.py`)
- Activation (`test_activation.py`): silut export + compile compatibility